use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use hyper_util::service::TowerToHyperService;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tower::{Service, ServiceExt};
use crate::{BodyError, BoxBody, BoxError};
/// Binds `addr` and serves HTTP/1 connections with `service` until shutdown is requested.
///
/// Each accepted connection runs on its own task inside a [`JoinSet`], driven by a clone of
/// `service`. When the future from `shutdown_signal` completes, the listener is closed so no
/// further connections are admitted, and the function waits for every outstanding connection
/// task to finish before returning.
///
/// Connection tasks are only awaited, never asked to stop, so the drain phase lasts as long
/// as the slowest remaining connection.
///
/// # Errors
///
/// Returns the boxed I/O error when binding `addr` fails, or when `accept` fails for a reason
/// other than a per-connection fault (see `is_transient_accept_error`). Returning early drops
/// the `JoinSet`, which aborts the connection tasks that are still running.
pub async fn serve<S>(addr: impl Into<SocketAddr>, service: S) -> Result<(), BoxError>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    let address = addr.into();
    let listener = TcpListener::bind(address)
        .await
        .map_err(|err| -> BoxError { Box::new(err) })?;
    #[cfg(feature = "tracing")]
    tracing::info!(addr = %listener.local_addr().unwrap_or(address), "listening for HTTP connections");
    let mut shutdown = Box::pin(shutdown_signal());
    let mut tasks = JoinSet::new();
    loop {
        tokio::select! {
            _ = &mut shutdown => {
                #[cfg(feature = "tracing")]
                tracing::info!("shutdown signal received");
                break;
            }
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(accepted) => accepted,
                    // The failure belongs to one would-be connection (the peer went away
                    // before the accept completed); the listener itself is still healthy,
                    // so keep serving instead of tearing down every in-flight connection.
                    Err(err) if is_transient_accept_error(&err) => {
                        #[cfg(feature = "tracing")]
                        tracing::warn!(error = %err, "failed to accept connection");
                        continue;
                    }
                    Err(err) => {
                        let boxed: BoxError = Box::new(err);
                        return Err(boxed);
                    }
                };
                #[cfg(not(feature = "tracing"))]
                let _ = &peer_addr;
                let service = service.clone();
                tasks.spawn(async move {
                    let io = TokioIo::new(stream);
                    let handler = ConnectionHandler::new(service);
                    let hyper_service = TowerToHyperService::new(handler);
                    if let Err(error) = http1::Builder::new().serve_connection(io, hyper_service).await {
                        #[cfg(feature = "tracing")]
                        tracing::error!(?peer_addr, ?error, "error serving connection");
                        #[cfg(not(feature = "tracing"))]
                        let _ = error;
                    }
                });
            }
            // Reap finished tasks as we go so the set does not grow without bound; the guard
            // keeps this branch disabled while there is nothing to join.
            join_result = tasks.join_next(), if !tasks.is_empty() => {
                if let Some(Err(join_error)) = join_result {
                    #[cfg(feature = "tracing")]
                    tracing::error!(?join_error, "connection task failed");
                    #[cfg(not(feature = "tracing"))]
                    let _ = join_error;
                }
            }
        }
    }
    // Close the listening socket before draining: otherwise new clients could still connect
    // and sit in the backlog, waiting on a server that will never accept them.
    drop(listener);
    while let Some(result) = tasks.join_next().await {
        if let Err(join_error) = result {
            #[cfg(feature = "tracing")]
            tracing::error!(?join_error, "connection task failed during shutdown");
            #[cfg(not(feature = "tracing"))]
            let _ = join_error;
        }
    }
    #[cfg(feature = "tracing")]
    tracing::info!("server shutdown complete");
    Ok(())
}

/// Reports whether an `accept` failure concerns only the connection being accepted.
///
/// These kinds mean the peer reset, aborted, or refused the connection before it was handed
/// to us; retrying `accept` is the appropriate reaction.
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}
/// Wraps a tower service so it can be handed to hyper as a per-connection handler.
///
/// The type itself places no bounds on `S`; each `impl` block states the requirements it
/// actually needs, so the struct can be named without restating them.
struct ConnectionHandler<S> {
    /// Inner service that ultimately handles each request.
    service: S,
}
impl<S> ConnectionHandler<S>
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
+ Clone
+ Send
+ Sync
+ 'static,
S::Future: Send + 'static,
{
fn new(service: S) -> Self {
Self { service }
}
}
impl<S> Clone for ConnectionHandler<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    /// Produces a handler backed by a clone of the inner service.
    fn clone(&self) -> Self {
        let service = self.service.clone();
        Self { service }
    }
}
/// Lets hyper's `Request<Incoming>` flow into a service that expects `Request<BoxBody>`.
///
/// The inner service is infallible, so every error branch below is an empty match on
/// `Infallible`; the `BoxError` error type is never actually produced here.
impl<S> Service<Request<Incoming>> for ConnectionHandler<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = BoxError;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Forwards the readiness of the inner service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.service.poll_ready(cx) {
            std::task::Poll::Pending => std::task::Poll::Pending,
            std::task::Poll::Ready(outcome) => std::task::Poll::Ready(match outcome {
                Ok(()) => Ok(()),
                Err(never) => match never {},
            }),
        }
    }

    /// Boxes the incoming body and dispatches the request on a clone of the inner service.
    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
        let mut inner = self.service.clone();
        Box::pin(async move {
            let (head, incoming) = req.into_parts();
            let body: BoxBody = incoming.map_err(|err| -> BodyError { err }).boxed();
            let request = Request::from_parts(head, body);
            // The clone has not been polled for readiness yet, so wait for it before calling.
            let ready = match inner.ready().await {
                Ok(svc) => svc,
                Err(never) => match never {},
            };
            match ready.call(request).await {
                Ok(response) => Ok::<_, BoxError>(response),
                Err(never) => match never {},
            }
        })
    }
}
/// Returns a future that completes once the process is asked to stop.
///
/// It resolves on Ctrl+C and, on Unix targets, also when the terminate signal is received.
/// On other targets the terminate branch never completes.
///
/// # Panics
///
/// The returned future panics when polled if either signal handler cannot be installed.
fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
    let wait_for_stop = async {
        let interrupt = async {
            tokio::signal::ctrl_c()
                .await
                .expect("failed to install Ctrl+C handler");
        };

        #[cfg(unix)]
        let terminate = async {
            let mut stream =
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                    .expect("failed to install terminate signal handler");
            stream.recv().await;
        };
        // Non-Unix targets have no terminate signal, so this arm simply never fires.
        #[cfg(not(unix))]
        let terminate = std::future::pending::<()>();

        tokio::select! {
            _ = interrupt => {},
            _ = terminate => {},
        }
    };
    Box::pin(wait_for_stop)
}