use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use http::Uri;
use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
use tower_service::Service;
use crate::{TlsAcceptor, TlsConnector, TlsStream};
/// Server-side adapter that performs a TLS handshake on an accepted TCP
/// connection and wraps the resulting [`TlsStream`] in hyper-util's
/// [`TokioIo`] adapter.
#[derive(Clone, Debug)]
pub struct HyperAcceptor {
    /// The TLS acceptor that performs the actual handshake.
    inner: TlsAcceptor,
}

impl HyperAcceptor {
    /// Wraps an existing [`TlsAcceptor`].
    #[must_use]
    pub fn new(acceptor: TlsAcceptor) -> Self {
        Self { inner: acceptor }
    }

    /// Runs the server-side TLS handshake over `tcp` and returns the
    /// encrypted stream wrapped in [`TokioIo`].
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the inner [`TlsAcceptor::accept`].
    pub async fn accept(&self, tcp: TcpStream) -> crate::error::Result<TokioIo<TlsStream>> {
        Ok(TokioIo::new(self.inner.accept(tcp).await?))
    }

    /// Borrows the wrapped [`TlsAcceptor`].
    #[must_use]
    pub fn inner(&self) -> &TlsAcceptor {
        &self.inner
    }
}

impl From<TlsAcceptor> for HyperAcceptor {
    /// Equivalent to [`HyperAcceptor::new`].
    fn from(acceptor: TlsAcceptor) -> Self {
        Self { inner: acceptor }
    }
}
/// Client connector for hyper-util that dials TCP with an [`HttpConnector`]
/// and then performs a TLS handshake over it with a [`TlsConnector`].
///
/// Only `https://` URIs are accepted by its [`Service`] implementation.
#[derive(Clone, Debug)]
pub struct HttpsConnector {
    /// Plain TCP connector used to reach the remote host.
    http: HttpConnector,
    /// TLS client used to secure the TCP stream.
    tls: TlsConnector,
}

impl HttpsConnector {
    /// Creates a connector backed by a default [`HttpConnector`].
    #[must_use]
    pub fn new(tls: TlsConnector) -> Self {
        Self::with_http_connector(HttpConnector::new(), tls)
    }

    /// Creates a connector from a caller-configured [`HttpConnector`].
    ///
    /// `enforce_http` is always switched off on the supplied connector.
    /// This type only ever dials `https://` URIs, which an `HttpConnector`
    /// with `enforce_http(true)` (its default) would reject on every call.
    #[must_use]
    pub fn with_http_connector(mut http: HttpConnector, tls: TlsConnector) -> Self {
        http.enforce_http(false);
        Self { http, tls }
    }

    /// Borrows the TLS client configuration.
    #[must_use]
    pub fn tls(&self) -> &TlsConnector {
        &self.tls
    }

    /// Borrows the underlying TCP connector.
    #[must_use]
    pub fn http(&self) -> &HttpConnector {
        &self.http
    }
}

impl From<TlsConnector> for HttpsConnector {
    /// Equivalent to [`HttpsConnector::new`].
    fn from(tls: TlsConnector) -> Self {
        Self::new(tls)
    }
}
/// Boxed, `Send` future returned by `HttpsConnector`'s [`Service::call`].
type ConnFuture = Pin<Box<dyn Future<Output = io::Result<TokioIo<TlsStream>>> + Send>>;

/// Returns `host` with one pair of enclosing square brackets removed.
///
/// `Uri::host` keeps the brackets around IPv6 literals (e.g. `"[::1]"`).
/// Those brackets are URI syntax, not part of the address, and make the value
/// unusable as a TLS server name. Hosts without brackets are returned as-is.
fn tls_server_name(host: &str) -> &str {
    host.strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host)
}

impl Service<Uri> for HttpsConnector {
    type Response = TokioIo<TlsStream>;
    type Error = io::Error;
    type Future = ConnFuture;

    /// Delegates readiness to the inner TCP connector, converting its error
    /// into an `io::Error`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Service::<Uri>::poll_ready(&mut self.http, cx).map_err(io::Error::other)
    }

    /// Opens a TLS connection to `uri`.
    ///
    /// The URI host, with IPv6 brackets removed, is passed to
    /// `TlsConnector::connect` as the server name.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an `io::Error` in any of these cases:
    ///
    /// - the scheme is not `https`;
    /// - the URI has no host;
    /// - the TCP connect fails;
    /// - the TLS handshake fails.
    fn call(&mut self, uri: Uri) -> Self::Future {
        // Take owned copies up front: `uri` is moved into the TCP connector,
        // and the boxed future must not borrow `self`.
        let scheme = uri.scheme_str().map(str::to_owned);
        let host = uri.host().map(|h| tls_server_name(h).to_owned());
        let mut http = self.http.clone();
        let tls = self.tls.clone();
        Box::pin(async move {
            if scheme.as_deref() != Some("https") {
                return Err(io::Error::other(format!(
                    "HttpsConnector requires https://; got scheme {scheme:?}"
                )));
            }
            let host = host.ok_or_else(|| io::Error::other("URI is missing a host"))?;
            let tcp_io = Service::<Uri>::call(&mut http, uri)
                .await
                .map_err(io::Error::other)?;
            let tcp = tcp_io.into_inner();
            let tls_stream = tls.connect(&host, tcp).await.map_err(io::Error::other)?;
            Ok(TokioIo::new(tls_stream))
        })
    }
}
impl Connection for TlsStream {
    /// Reports connection metadata to the hyper-util client.
    ///
    /// The result is marked as HTTP/2 when the TLS handshake negotiated the
    /// `h2` ALPN protocol. Otherwise the default [`Connected`] value is
    /// returned unchanged.
    fn connected(&self) -> Connected {
        let base = Connected::new();
        let speaks_h2 = self.negotiated().alpn().is_some_and(|proto| proto == b"h2");
        if !speaks_h2 {
            return base;
        }
        base.negotiated_h2()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::{ClientConfig, ServerConfig};

    /// Test certificate, used both as the server identity and as the
    /// client's trusted root.
    const TEST_CERT_PEM: &[u8] = include_bytes!("../tests/data/cert.pem");
    /// Private key paired with `TEST_CERT_PEM` for the server config.
    const TEST_KEY_PEM: &[u8] = include_bytes!("../tests/data/key.pem");

    /// Builds a server-side acceptor from the test certificate and key.
    fn test_acceptor() -> TlsAcceptor {
        let server_config = ServerConfig::builder()
            .with_pem_bytes(TEST_CERT_PEM, TEST_KEY_PEM)
            .expect("ServerConfig builds");
        TlsAcceptor::new(Arc::new(server_config))
    }

    /// Builds a client-side connector that trusts the test certificate.
    fn test_connector() -> TlsConnector {
        let client_config = ClientConfig::builder()
            .with_root_certs_pem_bytes(TEST_CERT_PEM)
            .build()
            .expect("ClientConfig builds");
        TlsConnector::new(Arc::new(client_config))
    }

    #[test]
    fn hyper_acceptor_from_and_inner() {
        let wrapped = HyperAcceptor::from(test_acceptor());
        let _: TlsAcceptor = wrapped.inner().clone();
        let _ = wrapped.clone();
    }

    #[test]
    fn https_connector_default_constructor() {
        let connector = HttpsConnector::new(test_connector());
        let _: &TlsConnector = connector.tls();
        let _: &HttpConnector = connector.http();
        let _ = connector.clone();
    }

    #[test]
    fn https_connector_with_custom_http_connector() {
        let mut tcp_connector = HttpConnector::new();
        tcp_connector.enforce_http(false);
        let connector = HttpsConnector::with_http_connector(tcp_connector, test_connector());
        let _: &HttpConnector = connector.http();
    }
}