use std::{
convert::Infallible,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::server::conn::auto::Builder;
use pin_project_lite::pin_project;
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::util::{Oneshot, ServiceExt};
use tower_service::Service;
/// Bundle a TCP listener and a service factory into a [`Serve`] value.
///
/// `make_service` is a service over [`IncomingStream`] whose response is the
/// per-connection service `S`. This function performs no I/O itself; it only
/// stores its arguments. The accept loop runs once the returned value is
/// converted with [`IntoFuture`] (for example by `.await`ing it).
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `S` only appears in the bounds, so it is carried as a zero-sized marker.
    let _marker = PhantomData::<S>;
    Serve {
        tcp_listener,
        make_service,
        _marker,
    }
}
/// Value returned by [`serve`].
///
/// It is not a future itself; it implements [`IntoFuture`], so `.await`ing it
/// runs the accept loop.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
// Listener that new connections are accepted from.
tcp_listener: TcpListener,
// Service factory, called once per accepted connection.
make_service: M,
// `S` (the per-connection service type) is named but never stored.
_marker: PhantomData<S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
    /// Attach a shutdown `signal` to this server.
    ///
    /// Consumes `self` and moves the listener and service factory, unchanged,
    /// into a [`WithGracefulShutdown`] together with `signal`. Nothing is
    /// polled or spawned here.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            tcp_listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            tcp_listener,
            make_service,
            signal,
            _marker,
        }
    }
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
    M: Debug,
{
    /// Formats as a struct named `Serve` showing `tcp_listener` and
    /// `make_service`; the `PhantomData` marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Serve");
        out.field("tcp_listener", &self.tcp_listener);
        out.field("make_service", &self.make_service);
        out.finish()
    }
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Build the future that runs the accept loop.
    ///
    /// For every accepted connection the loop waits for `make_service` to be
    /// ready, asks it for a per-connection service, and spawns a task that
    /// serves the connection. The loop contains no `break` or `return`, so
    /// the returned future does not complete on its own.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                _marker: _,
            } = self;

            loop {
                // `tcp_accept` yields `None` after a failed accept; just try again.
                let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
                    Some(conn) => conn,
                    None => continue,
                };
                let tcp_stream = TokioIo::new(tcp_stream);

                // The error type is `Infallible`, so the empty match proves
                // the error branch can never be taken.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                // Each connection is driven on its own task so a slow peer
                // does not stall the accept loop.
                tokio::spawn(async move {
                    let served = Builder::new(TokioExecutor::new())
                        .serve_connection_with_upgrades(tcp_stream, hyper_service)
                        .await;

                    // A failed connection must not take the server down, but
                    // it should not vanish silently either: log it exactly as
                    // the graceful-shutdown variant does.
                    if let Err(_err) = served {
                        trace!("failed to serve connection: {_err:#}");
                    }
                });
            }
        }))
    }
}
/// Value returned by `Serve::with_graceful_shutdown`: a server plus a shutdown signal.
///
/// Like `Serve`, it is driven through its [`IntoFuture`] implementation.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
// Listener that new connections are accepted from.
tcp_listener: TcpListener,
// Service factory, called once per accepted connection.
make_service: M,
// Future whose completion starts the shutdown.
signal: F,
// `S` (the per-connection service type) is named but never stored.
_marker: PhantomData<S>,
}
// Only the types that are actually formatted need `Debug`. `S` appears solely
// inside `PhantomData<S>`, which is skipped below, so it carries no bound
// (matching the `Debug` impl for `Serve`).
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
    M: Debug,
    F: Debug,
{
    /// Formats as a struct named `WithGracefulShutdown` showing
    /// `tcp_listener`, `make_service` and `signal`; the marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without deciding how to
        // print it becomes a compile error.
        let Self {
            tcp_listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("tcp_listener", tcp_listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
// Accept loop with graceful shutdown.
//
// Two `watch` channels are used purely for their "all receivers dropped"
// notification (`Sender::closed`), never for sending values:
//   * signal channel: its only receiver is dropped once `signal` completes,
//     which wakes the accept loop and every connection task;
//   * close channel: every connection task holds a receiver clone and drops
//     it when done, so the accept side can wait for all tasks to finish.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;
fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut make_service,
signal,
_marker: _,
} = self;
let (signal_tx, signal_rx) = watch::channel(());
// Shared so each connection task can await `closed()` on the same sender.
let signal_tx = Arc::new(signal_tx);
// NOTE: this task is spawned when `into_future` is called, not when the
// returned future is first polled.
// NOTE(review): `tokio::spawn` presumably requires a running Tokio runtime
// at this point; confirm against the tokio docs.
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
// Dropping the only receiver is what makes `signal_tx.closed()` resolve.
drop(signal_rx);
});
let (close_tx, close_rx) = watch::channel(());
private::ServeFuture(Box::pin(async move {
loop {
// Race the next accept against the shutdown signal.
let (tcp_stream, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => {
match conn {
Some(conn) => conn,
// Failed accept: go around and try again.
None => continue,
}
}
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};
let tcp_stream = TokioIo::new(tcp_stream);
trace!("connection {remote_addr} accepted");
// The error type is `Infallible`; the empty match proves the error
// branch is unreachable.
poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});
let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});
let hyper_service = TowerToHyperService {
service: tower_service,
};
// Per-task handles: one to observe the signal, one to report completion.
let signal_tx = Arc::clone(&signal_tx);
let close_rx = close_rx.clone();
tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
pin_mut!(conn);
// Fused because the select loop below keeps polling this future
// after it has completed.
let signal_closed = signal_tx.closed().fuse();
pin_mut!(signal_closed);
loop {
tokio::select! {
// Connection finished (successfully or not): leave the loop.
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
// Shutdown requested: ask the connection to wind down, then
// keep looping so `conn` is still driven to completion.
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
}
trace!("connection {remote_addr} closed");
// Tell the accept side that this task is done.
drop(close_rx);
});
}
// Drop our own receiver so only connection tasks keep the close channel
// open, and release the listener now that accepting has stopped.
drop(close_rx);
drop(tcp_listener);
trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
// Resolves once every connection task has dropped its `close_rx`.
close_tx.closed().await;
Ok(())
}))
}
}
/// Returns `true` when `e` is one of the per-connection error kinds
/// (`ConnectionRefused`, `ConnectionAborted`, `ConnectionReset`), and `false`
/// for every other kind.
fn is_connection_error(e: &io::Error) -> bool {
    const CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    CONNECTION_KINDS.contains(&e.kind())
}
/// Accept one connection from `listener`.
///
/// Returns `Some((stream, peer_addr))` on success and `None` on any accept
/// error, so callers can simply retry. Errors classified by
/// `is_connection_error` return `None` immediately; every other error is
/// logged and followed by a one-second sleep before `None` is returned.
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
    let e = match listener.accept().await {
        Ok(conn) => return Some(conn),
        Err(e) => e,
    };

    if !is_connection_error(&e) {
        // NOTE(review): the pause presumably keeps a persistent failure (for
        // example running out of file descriptors) from spinning the accept
        // loop; confirm the intent.
        error!("accept error: {e}");
        tokio::time::sleep(Duration::from_secs(1)).await;
    }

    None
}
// `ServeFuture` is `pub` but lives in a non-`pub` module; it is the
// `IntoFuture::IntoFuture` type of both server values in this file.
mod private {
    use std::future::Future;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Boxed future that drives a server's accept loop.
    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);

    impl Future for ServeFuture {
        type Output = io::Result<()>;

        /// Forwards the poll to the boxed inner future.
        #[inline]
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // The only field is an already-pinned box, so the wrapper is
            // `Unpin` and can be borrowed mutably out of the `Pin`.
            let Self(inner) = self.get_mut();
            inner.as_mut().poll(cx)
        }
    }

    impl std::fmt::Debug for ServeFuture {
        /// The boxed future is opaque, so only the type name is printed.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let mut out = f.debug_struct("ServeFuture");
            out.finish_non_exhaustive()
        }
    }
}
/// Adapter that wraps a `tower_service::Service` so it can be handed to hyper
/// as a `hyper::service::Service`.
#[derive(Debug, Copy, Clone)]
struct TowerToHyperService<S> {
// The wrapped tower service; cloned for every request.
service: S,
}
impl<S> hyper::service::Service<Request<Incoming>> for TowerToHyperService<S>
where
    S: tower_service::Service<Request> + Clone,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = TowerToHyperServiceFuture<S, Request>;

    /// Handle one hyper request with a fresh clone of the wrapped service.
    ///
    /// hyper calls through `&self`, while the tower service is consumed by
    /// `oneshot`, hence the clone. The incoming body is wrapped in `Body` so
    /// the request has the type the tower service expects.
    fn call(&self, req: Request<Incoming>) -> Self::Future {
        let request = req.map(Body::new);
        let service = self.service.clone();
        let future = service.oneshot(request);
        TowerToHyperServiceFuture { future }
    }
}
// Response future returned by `TowerToHyperService::call`: a thin wrapper
// around tower's `Oneshot` future.
pin_project! {
struct TowerToHyperServiceFuture<S, R>
where
S: tower_service::Service<R>,
{
// Marked `#[pin]` so the projection hands out a pinned reference that
// can be polled.
#[pin]
future: Oneshot<S, R>,
}
}
impl<S, R> Future for TowerToHyperServiceFuture<S, R>
where
    S: tower_service::Service<R>,
{
    type Output = Result<S::Response, S::Error>;

    /// Polls the wrapped `Oneshot` future and returns its result unchanged.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.future.poll(cx)
    }
}
/// A freshly accepted connection, passed to the `make_service` factory as its
/// request so it can build the service for that connection.
#[derive(Debug)]
pub struct IncomingStream<'a> {
// Borrowed: the stream itself is later moved into the connection task.
tcp_stream: &'a TokioIo<TcpStream>,
// Peer address returned by the accept call.
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
    /// Returns the local address of the underlying TCP stream.
    ///
    /// # Errors
    ///
    /// Propagates any error from the stream's own `local_addr` call.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        let stream = self.tcp_stream.inner();
        stream.local_addr()
    }

    /// Returns the peer address stored with this connection.
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = self;
        *remote_addr
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
handler::{Handler, HandlerWithoutStateExt},
routing::get,
Router,
};
// Compile-only check: this function is never run. It exists so that a change
// to the bounds on `serve` that rejects one of these argument types fails the
// build. `dead_code` is allowed because nothing calls it, `unused_must_use`
// because the `Serve` values are intentionally discarded.
#[allow(dead_code, unused_must_use)]
async fn if_it_compiles_it_works() {
let router: Router = Router::new();
let addr = "0.0.0.0:0";
// A `Router`: directly, as a make-service, and with connect info.
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(
TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
router.into_make_service_with_connect_info::<SocketAddr>(),
);
// The value returned by `get(handler)`, in the same three forms.
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<SocketAddr>(),
);
// A bare handler function converted in each supported way.
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.with_state(()),
);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(),
);
}
// Minimal handler used only to give the calls above something to serve.
async fn handler() {}
}