use std::net::SocketAddr;
use hyper::body::{Body, Incoming};
use hyper::server::conn::http1;
use hyper::service::HttpService;
use hyper_util::rt::TokioIo;
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use crate::stream::HttpOrHttpsStream;
/// Accepts TCP connections from a listener and produces, for each one, a future that
/// serves HTTP/1 on it — directly over TCP, or over TLS when a [`TlsAcceptor`] has been
/// configured with `with_tls`.
pub struct HttpOrHttpsAcceptor {
    /// Listener that incoming TCP connections are accepted from.
    listener: TcpListener,
    /// When `Some`, each accepted connection performs a TLS handshake with this acceptor
    /// before HTTP is served; when `None`, HTTP is served on the raw TCP stream.
    tls: Option<TlsAcceptor>,
}
impl HttpOrHttpsAcceptor {
pub const fn new(listener: TcpListener) -> Self {
Self {
listener,
tls: None,
}
}
#[must_use]
pub fn with_tls(mut self, tls: TlsAcceptor) -> Self {
self.tls = Some(tls);
self
}
pub async fn accept<S>(
&self,
service: S,
) -> Result<
(
SocketAddr,
impl Future<Output = Result<(), AcceptorError>> + use<S>,
),
AcceptorError,
>
where
S: HttpService<Incoming> + 'static,
<S::ResBody as Body>::Error: std::error::Error + Send + Sync,
{
match self.listener.accept().await {
Ok((stream, peer_addr)) => {
let cloned_tls = self.tls.clone();
let conn_fut = handle_conn(stream, cloned_tls, service);
Ok((peer_addr, conn_fut))
}
Err(e) => Err(AcceptorError::TcpConnect(e)),
}
}
}
async fn handle_conn<S>(
stream: TcpStream,
tls: Option<TlsAcceptor>,
handler: S,
) -> Result<(), AcceptorError>
where
S: HttpService<Incoming>,
S::ResBody: 'static,
<S::ResBody as Body>::Error: std::error::Error + Send + Sync,
{
let client = match tls {
None => HttpOrHttpsStream::Http(stream),
Some(tls) => {
let tls_stream = tls
.accept(stream)
.await
.map_err(AcceptorError::TlsHandshake)?;
HttpOrHttpsStream::Https(tls_stream)
}
};
http1::Builder::new()
.serve_connection(TokioIo::new(client), handler)
.with_upgrades()
.await
.map_err(AcceptorError::Hyper)
}
/// Errors produced while accepting a client connection or serving HTTP on it.
///
/// Each variant wraps the underlying error as its `source()`; the Display message of
/// the variant itself does not include the wrapped error's text.
#[derive(Error, Debug)]
pub enum AcceptorError {
    /// Accepting a connection from the TCP listener failed.
    #[error("TCP connection to client failed")]
    TcpConnect(#[source] std::io::Error),
    /// The TLS handshake with the client failed.
    #[error("TLS handshake with client failed")]
    TlsHandshake(#[source] std::io::Error),
    /// Hyper reported an error while serving HTTP on the connection.
    #[error("Failed to serve HTTP connection")]
    Hyper(#[source] hyper::Error),
}