use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use http::Request;
use http::Response;
use http::Uri;
use super::{BoxFuture, ClientBody, ClientTransport};
use crate::error::ConnectError;
/// Type-erased error used throughout this file: any boxed
/// `std::error::Error` that is also `Send + Sync`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[cfg(feature = "client-tls")]
use std::sync::Arc;
/// Marker trait bundling hyper's `Read` + `Write` with `Send + Unpin` so that
/// different stream types can be stored behind a single `dyn H2Io`.
trait H2Io: hyper::rt::Read + hyper::rt::Write + Send + Unpin {}
/// Blanket impl: every type satisfying the bounds is automatically an `H2Io`.
impl<T: hyper::rt::Read + hyper::rt::Write + Send + Unpin> H2Io for T {}
/// Pinned, heap-allocated, type-erased I/O stream.
type BoxedIo = Pin<Box<dyn H2Io>>;
/// Type-erased connector: a service that maps a `Uri` to a `BoxedIo`,
/// failing with a `BoxError`.
type BoxedConnector = tower::util::BoxService<Uri, BoxedIo, BoxError>;
/// Erases the concrete type of a user-supplied connector.
///
/// The connector's response is pinned and boxed into a [`BoxedIo`], its error
/// is converted into a [`BoxError`], and the adapted service is wrapped in a
/// `tower::util::BoxService` so callers can store it without generics.
fn box_connector<C>(connector: C) -> BoxedConnector
where
    C: tower::Service<Uri> + Send + 'static,
    C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
    C::Error: Into<BoxError>,
    C::Future: Send + 'static,
{
    use tower::ServiceExt;

    let erased = connector
        .map_response(|stream| -> BoxedIo { Box::pin(stream) })
        .map_err(|err| -> BoxError { err.into() });
    tower::util::BoxService::new(erased)
}
/// Builds a connector that dials the Unix domain socket at `path`.
///
/// The `Uri` handed to the service is ignored; every call connects to the
/// same socket path. Connection failures are reported as
/// `ConnectError::unavailable` with the socket path in the message.
#[cfg(unix)]
fn unix_connector(
    path: std::path::PathBuf,
) -> impl tower::Service<
    Uri,
    Response = hyper_util::rt::TokioIo<tokio::net::UnixStream>,
    Error = ConnectError,
    Future: Send + 'static,
> + Send
       + 'static {
    tower::service_fn(move |_uri: Uri| {
        // Each returned future needs its own owned copy of the path.
        let socket_path = path.clone();
        async move {
            match tokio::net::UnixStream::connect(&socket_path).await {
                Ok(stream) => Ok(hyper_util::rt::TokioIo::new(stream)),
                Err(e) => Err(ConnectError::unavailable(format!(
                    "unix socket connect to {} failed: {e}",
                    socket_path.display()
                ))),
            }
        }
    })
}
/// Returns a copy of `config` whose ALPN list is exactly `["h2"]`.
///
/// The caller's configuration is left untouched; any ALPN protocols it
/// previously listed are replaced in the copy.
#[cfg(feature = "client-tls")]
fn prepare_tls_for_h2(config: &Arc<rustls::ClientConfig>) -> Arc<rustls::ClientConfig> {
    let base: &rustls::ClientConfig = config;
    let mut h2_only = base.clone();
    h2_only.alpn_protocols = vec![b"h2".to_vec()];
    Arc::new(h2_only)
}
/// Derives the TLS server name from the host component of `uri`.
///
/// Square brackets around the host (as used for IPv6 literals) are removed
/// before parsing.
///
/// # Errors
///
/// Returns `ConnectError::invalid_argument` when the URI has no host or the
/// host cannot be parsed as a server name.
#[cfg(feature = "client-tls")]
fn server_name_from_uri(uri: &Uri) -> Result<rustls_pki_types::ServerName<'static>, ConnectError> {
    let Some(host) = uri.host() else {
        return Err(ConnectError::invalid_argument(
            "URI must have a host for TLS server name resolution",
        ));
    };

    // Only strip when BOTH brackets are present; otherwise keep the host as is.
    let bare = match host.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
        Some(inner) => inner,
        None => host,
    };

    match rustls_pki_types::ServerName::try_from(bare.to_owned()) {
        Ok(name) => Ok(name),
        Err(e) => Err(ConnectError::invalid_argument(format!(
            "invalid TLS server name '{host}': {e}"
        ))),
    }
}
/// Checks that `uri` uses the `https` scheme.
///
/// # Errors
///
/// Returns `ConnectError::invalid_argument` for `http`, for a missing scheme,
/// and (with a different message) for any other scheme.
#[cfg(feature = "client-tls")]
fn require_https_scheme(uri: &Uri) -> Result<(), ConnectError> {
    match uri.scheme_str() {
        Some("https") => Ok(()),
        Some(other) if other != "http" => Err(ConnectError::invalid_argument(format!(
            "unsupported URI scheme: {other}"
        ))),
        // Remaining cases: explicit `http` or no scheme at all.
        _ => Err(ConnectError::invalid_argument(
            "Http2Connection TLS constructors require https:// scheme; \
             use connect_plaintext/lazy_plaintext for http://",
        )),
    }
}
/// A single HTTP/2 client connection that re-establishes itself.
///
/// All behavior is delegated to the wrapped `Reconnect<MakeSendRequest>`.
pub struct Http2Connection {
/// Reconnecting state machine that owns the connection factory and the
/// current connection state.
inner: Reconnect<MakeSendRequest>,
}
/// Debug output shows the target URI, the name of the current connection
/// state, and whether a connection has ever succeeded. The connection
/// internals themselves are not printed.
impl std::fmt::Debug for Http2Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reconnect = &self.inner;
        let state = match &reconnect.state {
            ReconnectState::Connected(_) => "Connected",
            ReconnectState::Connecting(_) => "Connecting",
            ReconnectState::Idle => "Idle",
        };

        let mut out = f.debug_struct("Http2Connection");
        out.field("uri", &reconnect.uri);
        out.field("state", &state);
        out.field("has_connected", &reconnect.has_connected);
        out.finish()
    }
}
/// Checks that `uri` is usable for a plaintext connection.
///
/// A URI with the `http` scheme or with no scheme at all is accepted.
///
/// # Errors
///
/// Returns `ConnectError::invalid_argument` for `https` and (with a
/// different message) for any other scheme.
fn require_http_scheme(uri: &Uri) -> Result<(), ConnectError> {
    // A scheme-less URI is treated the same as plain `http`.
    let Some(scheme) = uri.scheme_str() else {
        return Ok(());
    };

    match scheme {
        "http" => Ok(()),
        "https" => Err(ConnectError::invalid_argument(
            "Http2Connection plaintext constructors require http:// scheme; \
             use connect_tls/lazy_tls for https://",
        )),
        other => Err(ConnectError::invalid_argument(format!(
            "unsupported URI scheme: {other}"
        ))),
    }
}
impl Http2Connection {
pub fn lazy_plaintext(uri: Uri) -> Self {
Self {
inner: Reconnect::new(MakeSendRequest::new(), uri, true),
}
}
pub async fn connect_plaintext(uri: Uri) -> Result<Self, ConnectError> {
require_http_scheme(&uri)?;
let mut conn = Self {
inner: Reconnect::new(MakeSendRequest::new(), uri, false),
};
std::future::poll_fn(|cx| conn.inner.poll_ready(cx))
.await
.map_err(|e| ConnectError::unavailable(format!("connect failed: {e}")))?;
Ok(conn)
}
pub fn with_builder_plaintext(
uri: Uri,
builder: hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor>,
) -> Self {
Self {
inner: Reconnect::new(MakeSendRequest::with_builder(builder), uri, true),
}
}
pub fn lazy_with_connector<C>(connector: C, authority: Uri) -> Self
where
C: tower::Service<Uri> + Send + 'static,
C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
C::Error: Into<BoxError>,
C::Future: Send + 'static,
{
Self {
inner: Reconnect::new(
MakeSendRequest::new_custom(box_connector(connector)),
authority,
true,
),
}
}
pub async fn connect_with_connector<C>(
connector: C,
authority: Uri,
) -> Result<Self, ConnectError>
where
C: tower::Service<Uri> + Send + 'static,
C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
C::Error: Into<BoxError>,
C::Future: Send + 'static,
{
let mut conn = Self {
inner: Reconnect::new(
MakeSendRequest::new_custom(box_connector(connector)),
authority,
false,
),
};
std::future::poll_fn(|cx| conn.inner.poll_ready(cx))
.await
.map_err(|e| ConnectError::unavailable(format!("connect failed: {e}")))?;
Ok(conn)
}
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub fn lazy_unix(path: impl Into<std::path::PathBuf>, authority: Uri) -> Self {
Self::lazy_with_connector(unix_connector(path.into()), authority)
}
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
pub async fn connect_unix(
path: impl Into<std::path::PathBuf>,
authority: Uri,
) -> Result<Self, ConnectError> {
Self::connect_with_connector(unix_connector(path.into()), authority).await
}
#[cfg(feature = "client-tls")]
pub fn lazy_tls(uri: Uri, tls_config: Arc<rustls::ClientConfig>) -> Self {
Self {
inner: Reconnect::new(MakeSendRequest::new_tls(tls_config), uri, true),
}
}
#[cfg(feature = "client-tls")]
pub async fn connect_tls(
uri: Uri,
tls_config: Arc<rustls::ClientConfig>,
) -> Result<Self, ConnectError> {
require_https_scheme(&uri)?;
let mut conn = Self {
inner: Reconnect::new(MakeSendRequest::new_tls(tls_config), uri, false),
};
std::future::poll_fn(|cx| conn.inner.poll_ready(cx))
.await
.map_err(|e| ConnectError::unavailable(format!("TLS connect failed: {e}")))?;
Ok(conn)
}
#[cfg(feature = "client-tls")]
pub fn with_builder_tls(
uri: Uri,
builder: hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor>,
tls_config: Arc<rustls::ClientConfig>,
) -> Self {
Self {
inner: Reconnect::new(
MakeSendRequest::with_builder_tls(builder, tls_config),
uri,
true,
),
}
}
}
/// `tower::Service` facade: readiness and request dispatch are both forwarded
/// to the inner reconnecting state machine.
impl tower::Service<Request<ClientBody>> for Http2Connection {
    type Response = Response<hyper::body::Incoming>;
    type Error = BoxError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let reconnect = &mut self.inner;
        reconnect.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ClientBody>) -> Self::Future {
        let reconnect = &mut self.inner;
        reconnect.call(req)
    }
}
/// Cloneable handle to an HTTP/2 connection.
///
/// Wraps a `tower::buffer::Buffer`; every clone shares the same buffer handle.
// The `allow` below silences clippy's complaint about the nested generic
// type of the `inner` field.
#[derive(Clone)]
#[allow(clippy::type_complexity)] pub struct SharedHttp2Connection {
/// Buffer handle through which requests are submitted.
inner: tower::buffer::Buffer<
Request<ClientBody>,
BoxFuture<'static, Result<Response<hyper::body::Incoming>, BoxError>>,
>,
}
impl Http2Connection {
    /// Converts this connection into a cloneable [`SharedHttp2Connection`].
    ///
    /// The connection is moved into a `tower::buffer::Buffer` with capacity
    /// `bound`, and the buffer's worker is spawned with `tokio::spawn`, so
    /// this must be called from within a tokio runtime.
    pub fn shared(self, bound: usize) -> SharedHttp2Connection {
        let (handle, background) = tower::buffer::Buffer::pair(self, bound);
        // The worker is detached; its join handle is intentionally discarded.
        let _ = tokio::spawn(background);
        SharedHttp2Connection { inner: handle }
    }
}
/// `tower::Service` facade over the inner buffer. The buffer's response
/// future is boxed so the associated `Future` type stays nameable.
impl tower::Service<Request<ClientBody>> for SharedHttp2Connection {
    type Response = Response<hyper::body::Incoming>;
    type Error = BoxError;
    type Future = BoxFuture<'static, Result<Response<hyper::body::Incoming>, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let buffer = &mut self.inner;
        tower::Service::<Request<ClientBody>>::poll_ready(buffer, cx)
    }

    fn call(&mut self, req: Request<ClientBody>) -> Self::Future {
        let buffer = &mut self.inner;
        Box::pin(tower::Service::<Request<ClientBody>>::call(buffer, req))
    }
}
/// Adapts the shared connection to the crate's `ClientTransport` interface.
impl ClientTransport for SharedHttp2Connection {
    type ResponseBody = hyper::body::Incoming;
    type Error = ConnectError;

    /// Sends `request` on a clone of this handle.
    ///
    /// Any failure (readiness or dispatch) is reported as
    /// `ConnectError::unavailable`.
    fn send(
        &self,
        request: Request<ClientBody>,
    ) -> BoxFuture<'static, Result<Response<Self::ResponseBody>, Self::Error>> {
        use tower::ServiceExt;

        // The returned future must be `'static`, so it owns its own handle.
        let handle = self.clone();
        Box::pin(async move {
            match handle.oneshot(request).await {
                Ok(response) => Ok(response),
                Err(e) => Err(ConnectError::unavailable(format!("h2 send failed: {e}"))),
            }
        })
    }
}
/// Thin wrapper around hyper's HTTP/2 `SendRequest` handle for `ClientBody`
/// requests.
struct SendRequest {
/// The hyper request sender for an established HTTP/2 connection.
inner: hyper::client::conn::http2::SendRequest<ClientBody>,
}
/// Exposes the hyper sender as a `tower::Service`, converting hyper errors
/// into [`BoxError`].
impl tower::Service<Request<ClientBody>> for SendRequest {
    type Response = Response<hyper::body::Incoming>;
    type Error = BoxError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
        }
    }

    fn call(&mut self, req: Request<ClientBody>) -> Self::Future {
        let in_flight = self.inner.send_request(req);
        Box::pin(async move {
            let response = in_flight.await?;
            Ok(response)
        })
    }
}
/// Connection factory: given a `Uri`, produces a [`SendRequest`] for a newly
/// handshaken HTTP/2 connection.
struct MakeSendRequest {
/// TCP connector used when no custom connector is set.
connector: hyper_util::client::legacy::connect::HttpConnector,
/// hyper HTTP/2 builder used for the protocol handshake.
builder: hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor>,
/// TLS configuration; when `Some`, connections are wrapped in TLS.
#[cfg(feature = "client-tls")]
tls: Option<Arc<rustls::ClientConfig>>,
/// Caller-supplied connector; when `Some`, it is used instead of `connector`.
custom: Option<BoxedConnector>,
}
impl MakeSendRequest {
    /// TCP connector shared by every non-custom constructor: a fresh
    /// `HttpConnector` with `set_nodelay(true)` applied.
    fn tcp_connector() -> hyper_util::client::legacy::connect::HttpConnector {
        let mut connector = hyper_util::client::legacy::connect::HttpConnector::new();
        connector.set_nodelay(true);
        connector
    }

    /// HTTP/2 builder with default settings, running on a `TokioExecutor`.
    fn default_builder() -> hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor> {
        hyper::client::conn::http2::Builder::new(hyper_util::rt::TokioExecutor::new())
    }

    /// Plaintext factory with the default HTTP/2 builder.
    fn new() -> Self {
        Self::with_builder(Self::default_builder())
    }

    /// Plaintext factory using the caller's HTTP/2 `builder`.
    fn with_builder(
        builder: hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor>,
    ) -> Self {
        Self {
            connector: Self::tcp_connector(),
            builder,
            #[cfg(feature = "client-tls")]
            tls: None,
            custom: None,
        }
    }

    /// Factory that dials through `conn` instead of the TCP connector.
    ///
    /// The `connector` field is still populated (with an unconfigured
    /// `HttpConnector`), but the `Service` impl bypasses it whenever `custom`
    /// is set.
    fn new_custom(conn: BoxedConnector) -> Self {
        Self {
            connector: hyper_util::client::legacy::connect::HttpConnector::new(),
            builder: Self::default_builder(),
            #[cfg(feature = "client-tls")]
            tls: None,
            custom: Some(conn),
        }
    }

    /// TLS factory with the default HTTP/2 builder.
    #[cfg(feature = "client-tls")]
    fn new_tls(tls: Arc<rustls::ClientConfig>) -> Self {
        Self::with_builder_tls(Self::default_builder(), tls)
    }

    /// TLS factory using the caller's HTTP/2 `builder`.
    ///
    /// The stored TLS configuration is a copy of `tls` with ALPN set to `h2`.
    #[cfg(feature = "client-tls")]
    fn with_builder_tls(
        builder: hyper::client::conn::http2::Builder<hyper_util::rt::TokioExecutor>,
        tls: Arc<rustls::ClientConfig>,
    ) -> Self {
        let mut connector = Self::tcp_connector();
        // The TCP connector must accept `https` URIs; TLS is layered on top
        // of the stream it returns.
        connector.enforce_http(false);
        Self {
            connector,
            builder,
            tls: Some(prepare_tls_for_h2(&tls)),
            custom: None,
        }
    }
}
impl tower::Service<Uri> for MakeSendRequest {
type Response = SendRequest;
type Error = BoxError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(c) = &mut self.custom {
return c.poll_ready(cx);
}
<_ as tower::Service<Uri>>::poll_ready(&mut self.connector, cx).map_err(Into::into)
}
fn call(&mut self, uri: Uri) -> Self::Future {
if let Some(c) = &mut self.custom {
let io_fut = c.call(uri);
let builder = self.builder.clone();
return Box::pin(async move {
let io = io_fut.await?;
let (send_request, conn) = builder.handshake(io).await?;
tokio::spawn(async move {
if let Err(e) = conn.await {
tracing::debug!("h2 connection task exited with error: {e}");
}
});
Ok(SendRequest {
inner: send_request,
})
});
}
#[cfg(feature = "client-tls")]
let scheme_check = if self.tls.is_some() {
require_https_scheme(&uri)
} else {
require_http_scheme(&uri)
};
#[cfg(not(feature = "client-tls"))]
let scheme_check = require_http_scheme(&uri);
if let Err(e) = scheme_check {
return Box::pin(async move { Err(e.into()) });
}
#[cfg(feature = "client-tls")]
let tls = self.tls.clone();
#[cfg(feature = "client-tls")]
let server_name = match self.tls.is_some() {
true => Some(match server_name_from_uri(&uri) {
Ok(sn) => sn,
Err(e) => return Box::pin(async move { Err(e.into()) }),
}),
false => None,
};
let connect_fut = <_ as tower::Service<Uri>>::call(&mut self.connector, uri);
let builder = self.builder.clone();
Box::pin(async move {
let io = connect_fut.await.map_err(Into::<BoxError>::into)?;
#[cfg(feature = "client-tls")]
let io: BoxedIo = if let (Some(tls), Some(server_name)) = (tls, server_name) {
let tcp = io.into_inner();
let connector = tokio_rustls::TlsConnector::from(tls);
let tls_stream = connector
.connect(server_name, tcp)
.await
.map_err(|e| ConnectError::unavailable(format!("TLS handshake failed: {e}")))?;
let (_, session) = tls_stream.get_ref();
if session.alpn_protocol() != Some(b"h2") {
return Err(ConnectError::unavailable(
"TLS handshake succeeded but server did not negotiate \
HTTP/2 via ALPN (is the server h2-capable?)",
)
.into());
}
Box::pin(hyper_util::rt::TokioIo::new(tls_stream))
} else {
Box::pin(io)
};
let (send_request, conn) = builder.handshake(io).await?;
tokio::spawn(async move {
if let Err(e) = conn.await {
tracing::debug!("h2 connection task exited with error: {e}");
}
});
Ok(SendRequest {
inner: send_request,
})
})
}
}
/// State machine that creates a connection with `make` and replaces it when
/// the established connection reports an error from `poll_ready`.
struct Reconnect<M>
where
M: tower::Service<Uri>,
{
/// Factory called with `uri` to start each connection attempt.
make: M,
/// Target passed to `make` on every attempt.
uri: Uri,
/// Current connection state.
state: ReconnectState<M::Future, M::Response>,
/// Error from a failed attempt, held until the next `call` returns it.
deferred_error: Option<BoxError>,
/// Set to `true` once any connection attempt has succeeded.
has_connected: bool,
/// When `true`, a failed attempt is deferred to `call` even if no attempt
/// has succeeded yet.
lazy: bool,
}
/// Connection state tracked by `Reconnect`.
enum ReconnectState<F, S> {
/// No connection and no attempt in progress.
Idle,
/// A connection attempt is in progress; holds its pinned future.
Connecting(Pin<Box<F>>),
/// A connection is established; holds the resulting service.
Connected(S),
}
impl<M> Reconnect<M>
where
    M: tower::Service<Uri>,
{
    /// Creates an idle state machine: no connection, no pending error, and
    /// no record of a previous successful connection.
    fn new(make: M, uri: Uri, lazy: bool) -> Self {
        let state = ReconnectState::Idle;
        Self {
            lazy,
            has_connected: false,
            deferred_error: None,
            state,
            uri,
            make,
        }
    }
}
impl<M, S> Reconnect<M>
where
M: tower::Service<Uri, Response = S>,
M::Error: Into<BoxError>,
S: tower::Service<Request<ClientBody>>,
S::Error: Into<BoxError>,
S::Future: Send + 'static,
{
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
if self.deferred_error.is_some() {
return Poll::Ready(Ok(()));
}
loop {
match &mut self.state {
ReconnectState::Idle => {
if let Err(e) = futures::ready!(self.make.poll_ready(cx)) {
return Poll::Ready(Err(e.into()));
}
let fut = self.make.call(self.uri.clone());
self.state = ReconnectState::Connecting(Box::pin(fut));
}
ReconnectState::Connecting(fut) => match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(svc)) => {
self.state = ReconnectState::Connected(svc);
self.has_connected = true;
}
Poll::Ready(Err(e)) => {
let e: BoxError = e.into();
self.state = ReconnectState::Idle;
if self.has_connected || self.lazy {
tracing::debug!("h2 reconnect failed (will retry): {e}");
self.deferred_error = Some(e);
return Poll::Ready(Ok(()));
} else {
return Poll::Ready(Err(e));
}
}
},
ReconnectState::Connected(svc) => match svc.poll_ready(cx) {
Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(_)) => {
tracing::debug!("h2 connection lost; reconnecting");
self.state = ReconnectState::Idle;
}
},
}
}
}
fn call(
&mut self,
req: Request<ClientBody>,
) -> BoxFuture<'static, Result<S::Response, BoxError>> {
if let Some(e) = self.deferred_error.take() {
return Box::pin(async move { Err(e) });
}
match &mut self.state {
ReconnectState::Connected(svc) => {
let fut = svc.call(req);
Box::pin(async move { fut.await.map_err(Into::into) })
}
_ => {
Box::pin(async {
Err("Http2Connection::call before poll_ready returned Ready"
.to_string()
.into())
})
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lazy constructor must not start connecting on its own.
    #[test]
    fn lazy_plaintext_starts_idle() {
        let conn = Http2Connection::lazy_plaintext("http://localhost:0".parse().unwrap());
        assert!(matches!(conn.inner.state, ReconnectState::Idle));
        assert!(!conn.inner.has_connected);
    }

    #[tokio::test]
    async fn connect_plaintext_to_nonexistent_fails() {
        let err = Http2Connection::connect_plaintext("http://127.0.0.1:1".parse().unwrap()).await;
        assert!(err.is_err(), "expected connect to port 1 to fail");
    }

    #[tokio::test]
    async fn connect_plaintext_rejects_https() {
        let err = Http2Connection::connect_plaintext("https://localhost:8080".parse().unwrap())
            .await
            .unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
        assert!(err.message.as_deref().unwrap().contains("http://"));
    }

    #[test]
    fn require_http_scheme_cases() {
        assert!(require_http_scheme(&"http://foo".parse().unwrap()).is_ok());
        assert!(require_http_scheme(&"/path".parse().unwrap()).is_ok());
        assert!(require_http_scheme(&"https://foo".parse().unwrap()).is_err());
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn require_https_scheme_cases() {
        assert!(require_https_scheme(&"https://foo".parse().unwrap()).is_ok());
        assert!(require_https_scheme(&"http://foo".parse().unwrap()).is_err());
        assert!(require_https_scheme(&"/path".parse().unwrap()).is_err());
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn prepare_tls_for_h2_sets_alpn() {
        let cfg = Arc::new(
            rustls::ClientConfig::builder()
                .with_root_certificates(rustls::RootCertStore::empty())
                .with_no_client_auth(),
        );
        let prepared = prepare_tls_for_h2(&cfg);
        assert_eq!(prepared.alpn_protocols, vec![b"h2".to_vec()]);
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn prepare_tls_for_h2_shares_cert_resolver() {
        let cfg = Arc::new(
            rustls::ClientConfig::builder()
                .with_root_certificates(rustls::RootCertStore::empty())
                .with_no_client_auth(),
        );
        let prepared = prepare_tls_for_h2(&cfg);
        assert!(Arc::ptr_eq(
            &cfg.client_auth_cert_resolver,
            &prepared.client_auth_cert_resolver
        ));
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn server_name_from_uri_extracts_host() {
        let name = server_name_from_uri(&"https://example.com:8080/path".parse().unwrap()).unwrap();
        assert_eq!(format!("{name:?}"), "DnsName(\"example.com\")");
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn server_name_from_uri_ipv4() {
        let name = server_name_from_uri(&"https://10.0.0.1:8443".parse().unwrap()).unwrap();
        assert!(matches!(name, rustls_pki_types::ServerName::IpAddress(_)));
    }

    #[cfg(feature = "client-tls")]
    #[test]
    fn server_name_from_uri_ipv6_strips_brackets() {
        let name = server_name_from_uri(&"https://[::1]:8443".parse().unwrap()).unwrap();
        assert!(matches!(name, rustls_pki_types::ServerName::IpAddress(_)));
    }

    #[cfg(feature = "client-tls")]
    #[tokio::test]
    async fn connect_tls_rejects_http_scheme() {
        let cfg = Arc::new(
            rustls::ClientConfig::builder()
                .with_root_certificates(rustls::RootCertStore::empty())
                .with_no_client_auth(),
        );
        let result =
            Http2Connection::connect_tls("http://localhost:8080".parse().unwrap(), cfg).await;
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected http:// to be rejected"),
        };
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    /// The connector always fails, so staying `Idle` shows that construction
    /// never invoked it.
    #[test]
    fn lazy_with_connector_starts_idle() {
        let conn = Http2Connection::lazy_with_connector(
            tower::service_fn(|_uri: Uri| async {
                Err::<hyper_util::rt::TokioIo<tokio::net::TcpStream>, _>(std::io::Error::other(
                    "unreachable",
                ))
            }),
            "http://localhost".parse().unwrap(),
        );
        assert!(matches!(conn.inner.state, ReconnectState::Idle));
        assert!(!conn.inner.has_connected);
    }

    #[tokio::test]
    async fn connect_with_connector_propagates_error() {
        let err = Http2Connection::connect_with_connector(
            tower::service_fn(|_uri: Uri| async {
                Err::<hyper_util::rt::TokioIo<tokio::net::TcpStream>, _>(std::io::Error::other(
                    "dial refused",
                ))
            }),
            "http://localhost".parse().unwrap(),
        )
        .await
        .unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::Unavailable);
        assert!(
            err.message.as_deref().unwrap().contains("dial refused"),
            "error should propagate connector message, got: {err:?}"
        );
    }

    /// The socket path does not exist, so staying `Idle` shows that
    /// construction made no connection attempt.
    #[cfg(unix)]
    #[test]
    fn lazy_unix_starts_idle() {
        let conn = Http2Connection::lazy_unix(
            "/nonexistent/test.sock",
            "http://localhost".parse().unwrap(),
        );
        assert!(matches!(conn.inner.state, ReconnectState::Idle));
        assert!(!conn.inner.has_connected);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn connect_unix_nonexistent_fails() {
        let path = "/nonexistent/buffa-test.sock";
        let err = Http2Connection::connect_unix(path, "http://localhost".parse().unwrap())
            .await
            .unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::Unavailable);
        assert!(
            err.message.as_deref().unwrap().contains(path),
            "error should include socket path, got: {err:?}"
        );
    }
}