use std::future::Future;
use futures_core::future::BoxFuture;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tower::Service;
use tracing::Instrument;
use super::connection::ConnectionError;
use super::connection::HttpConnection;
use super::transport::TransportStream;
use super::Connection;
use crate::bridge::io::TokioIo;
use crate::info::HasConnectionInfo;
pub mod auto;
#[cfg(feature = "mocks")]
pub mod mock;
pub use hyper::client::conn::http1;
pub use hyper::client::conn::http2;
/// A request to establish an HTTP connection over an already-open transport.
///
/// This is the request type consumed by [`Protocol`] implementations, which are
/// `tower::Service<ProtocolRequest<IO>>`s.
#[derive(Debug)]
pub struct ProtocolRequest<IO: HasConnectionInfo> {
    /// The transport stream the HTTP connection will be layered on top of.
    pub transport: TransportStream<IO>,
    /// The HTTP protocol family requested for this connection.
    pub version: HttpProtocol,
}
/// A connector that layers an HTTP connection on top of a transport stream.
///
/// Implementors must also be `tower::Service`s that accept a [`ProtocolRequest`]
/// and respond with this trait's [`Protocol::Connection`]. A blanket
/// implementation in this module makes every such service a `Protocol`.
pub trait Protocol<IO>
where
    IO: HasConnectionInfo,
    Self: Service<ProtocolRequest<IO>, Response = Self::Connection>,
{
    /// Error produced when a connection cannot be established.
    type Error: std::error::Error + Send + Sync + 'static;

    /// The connection type produced on success.
    type Connection: Connection;

    /// Future returned by [`Protocol::connect`], resolving to the new connection
    /// or this trait's error. Required to be `Send + 'static`.
    type Future: Future<Output = Result<Self::Connection, <Self as Protocol<IO>>::Error>>
        + Send
        + 'static;

    /// Starts establishing a connection that speaks `version` over `transport`.
    fn connect(
        &mut self,
        transport: TransportStream<IO>,
        version: HttpProtocol,
    ) -> <Self as Protocol<IO>>::Future;

    /// Reports whether the connector is ready for a `connect` call.
    ///
    /// Mirrors `Service::poll_ready`, but is expressed in terms of this trait's
    /// own `Error` type. The fully-qualified path avoids ambiguity with
    /// `Service::Error`.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), <Self as Protocol<IO>>::Error>>;
}
/// Every `tower::Service` that turns a [`ProtocolRequest`] into a [`Connection`]
/// is automatically a [`Protocol`].
///
/// `connect` packages its arguments into a `ProtocolRequest` and forwards it to
/// `Service::call`. `poll_ready` forwards straight to `Service::poll_ready`.
impl<T, C, IO> Protocol<IO> for T
where
    IO: HasConnectionInfo,
    T: Service<ProtocolRequest<IO>, Response = C> + Send + 'static,
    T::Error: std::error::Error + Send + Sync + 'static,
    T::Future: Send + 'static,
    C: Connection,
{
    type Error = T::Error;
    type Connection = C;
    type Future = T::Future;

    fn connect(
        &mut self,
        transport: TransportStream<IO>,
        version: HttpProtocol,
    ) -> <Self as Protocol<IO>>::Future {
        let request = ProtocolRequest { transport, version };
        <T as Service<ProtocolRequest<IO>>>::call(self, request)
    }

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), <Self as Protocol<IO>>::Error>> {
        // Fully qualified so it is unambiguous which `poll_ready` is meant.
        <T as Service<ProtocolRequest<IO>>>::poll_ready(self, cx)
    }
}
/// The HTTP protocol family a connection speaks.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum HttpProtocol {
    /// HTTP/1.x.
    Http1,
    /// HTTP/2.
    Http2,
}
impl HttpProtocol {
pub fn multiplex(&self) -> bool {
matches!(self, Self::Http2)
}
pub fn version(&self) -> ::http::Version {
match self {
Self::Http1 => ::http::Version::HTTP_11,
Self::Http2 => ::http::Version::HTTP_2,
}
}
}
/// Maps an `http::Version` onto the protocol family used to build a connection.
///
/// HTTP/1.0 and HTTP/1.1 both map to [`HttpProtocol::Http1`]; HTTP/2 maps to
/// [`HttpProtocol::Http2`].
///
/// # Panics
///
/// Panics for any other version. The panic message includes the rejected
/// version so the failure can be diagnosed.
impl From<::http::Version> for HttpProtocol {
    fn from(version: ::http::Version) -> Self {
        match version {
            ::http::Version::HTTP_11 | ::http::Version::HTTP_10 => Self::Http1,
            ::http::Version::HTTP_2 => Self::Http2,
            other => panic!("Unsupported HTTP protocol: {other:?}"),
        }
    }
}
impl<IO> tower::Service<ProtocolRequest<IO>> for hyper::client::conn::http1::Builder
where
IO: HasConnectionInfo + AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
type Response = HttpConnection;
type Error = ConnectionError;
type Future = BoxFuture<'static, Result<HttpConnection, ConnectionError>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, req: ProtocolRequest<IO>) -> Self::Future {
let builder = self.clone();
let stream = req.transport.into_inner();
let info = stream.info();
let span = tracing::info_span!("connection", version=?http::Version::HTTP_11, peer=%info.remote_addr());
Box::pin(async move {
let (sender, conn) = builder
.handshake(TokioIo::new(stream))
.await
.map_err(|err| ConnectionError::Handshake(err.into()))?;
tokio::spawn(
async {
if let Err(err) = conn.await {
if err.is_user() {
tracing::error!(err = format!("{err:#}"), "h1 connection driver error");
} else {
tracing::debug!(err = format!("{err:#}"), "h1 connection driver error");
}
}
}
.instrument(span),
);
Ok(HttpConnection::h1(sender))
})
}
}
impl<E, IO> tower::Service<ProtocolRequest<IO>> for hyper::client::conn::http2::Builder<E>
where
E: hyper::rt::bounds::Http2ClientConnExec<crate::body::Body, TokioIo<IO>>
+ Unpin
+ Send
+ Sync
+ Clone
+ 'static,
IO: HasConnectionInfo + AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
type Response = HttpConnection;
type Error = ConnectionError;
type Future = BoxFuture<'static, Result<HttpConnection, ConnectionError>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, req: ProtocolRequest<IO>) -> Self::Future {
let builder = self.clone();
let stream = req.transport.into_inner();
let info = stream.info();
let span = tracing::info_span!("connection", version=?http::Version::HTTP_11, peer=%info.remote_addr());
Box::pin(async move {
let (sender, conn) = builder
.handshake(TokioIo::new(stream))
.await
.map_err(|err| ConnectionError::Handshake(err.into()))?;
tokio::spawn(
async {
if let Err(err) = conn.await {
if err.is_user() {
tracing::error!(err = format!("{err:#}"), "h2 connection driver error");
} else {
tracing::debug!(err = format!("{err:#}"), "h2 connection driver error");
}
}
}
.instrument(span),
);
Ok(HttpConnection::h2(sender))
})
}
}