use std::future::Future;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;
use futures_util::TryFuture;
use crate::filter::Filter;
use crate::reject::IsReject;
use crate::reply::Reply;
#[cfg(feature = "tls")]
use crate::tls::TlsConfigBuilder;
/// Creates a [`Server`] that will answer requests with `filter`.
///
/// The returned server is not yet listening: its acceptor is the
/// `accept::LazyTcp` placeholder and its run strategy is `run::Standard`.
/// HTTP/1 pipeline flushing starts out disabled.
pub fn serve<F>(filter: F) -> Server<F, accept::LazyTcp, run::Standard>
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
    F::Error: IsReject,
{
    let acceptor = accept::LazyTcp;
    let runner = run::Standard;
    Server {
        acceptor,
        filter,
        pipeline: false,
        runner,
    }
}
/// An HTTP server built around a filter.
///
/// The three type parameters track how the server has been configured:
/// `F` is the filter, `A` is the connection acceptor, and `R` is the run
/// strategy that drives the accept loop.
#[derive(Debug)]
pub struct Server<F, A, R> {
/// Source of incoming connections. `serve` starts with the
/// `accept::LazyTcp` placeholder; `bind` / `incoming` replace it.
acceptor: A,
/// The filter; the run loops clone it once per accepted connection.
filter: F,
/// Forwarded to hyper's HTTP/1 `pipeline_flush` setting for every
/// connection. `serve` initialises it to `false`.
pipeline: bool,
/// Run strategy; `Server::run` dispatches to `R::run`.
runner: R,
}
/// Methods available while the server has not been given a listener yet.
impl<F, R> Server<F, accept::LazyTcp, R>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
    <F::Future as TryFuture>::Error: IsReject,
    R: run::Run,
{
    /// Binds a TCP listener to `addr` and then runs the server on it.
    ///
    /// # Panics
    ///
    /// Panics if binding fails (see [`Server::bind`]).
    pub async fn run(self, addr: impl Into<SocketAddr>) {
        let bound = self.bind(addr).await;
        bound.run().await;
    }

    /// Binds a `tokio::net::TcpListener` to `addr` and installs it as the
    /// server's acceptor.
    ///
    /// # Panics
    ///
    /// Panics with "failed to bind to address" if the listener cannot be
    /// bound.
    pub async fn bind(self, addr: impl Into<SocketAddr>) -> Server<F, tokio::net::TcpListener, R> {
        let target: SocketAddr = addr.into();
        let listener = tokio::net::TcpListener::bind(target)
            .await
            .expect("failed to bind to address");
        self.incoming(listener)
    }

    /// Replaces the placeholder acceptor with `acceptor`, keeping the filter,
    /// pipeline setting and run strategy unchanged.
    pub fn incoming<A>(self, acceptor: A) -> Server<F, A, R> {
        let Server {
            filter,
            pipeline,
            runner,
            ..
        } = self;
        Server {
            acceptor,
            filter,
            pipeline,
            runner,
        }
    }
}
/// Methods available once the server has a real acceptor.
impl<F, A, R> Server<F, A, R>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
    <F::Future as TryFuture>::Error: IsReject,
    A: accept::Accept,
    R: run::Run,
{
    /// Wraps the current acceptor in `accept::Tls`, which switches the server
    /// to the TLS builder methods (`key_path`, `cert_path`, ...).
    ///
    /// The filter, pipeline setting and run strategy are carried over
    /// unchanged.
    #[cfg(feature = "tls")]
    pub fn tls(self) -> Server<F, accept::Tls<A>, R> {
        // The body was previously empty, which evaluates to `()` and cannot
        // satisfy the declared return type. Rebuild the server around the
        // wrapped acceptor, mirroring `graceful` / `incoming`.
        Server {
            acceptor: accept::Tls(self.acceptor),
            filter: self.filter,
            pipeline: self.pipeline,
            runner: self.runner,
        }
    }

    /// Switches the run strategy to `run::Graceful`.
    ///
    /// With that strategy the accept loop stops as soon as `shutdown_signal`
    /// completes, and `run` then waits for the connections it is still
    /// watching before returning.
    pub fn graceful<Fut>(self, shutdown_signal: Fut) -> Server<F, A, run::Graceful<Fut>>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        Server {
            acceptor: self.acceptor,
            filter: self.filter,
            pipeline: self.pipeline,
            runner: run::Graceful(shutdown_signal),
        }
    }

    /// Runs the server using its configured run strategy `R`.
    pub async fn run(self) {
        R::run(self).await;
    }
}
// TLS configuration methods, available once the acceptor has been wrapped in
// `accept::Tls`. Every public method forwards its argument to the
// `TlsConfigBuilder` method of the same name through `with_tls`.
//
// NOTE(review): the exact meaning of each `TlsConfigBuilder` setter lives in
// `crate::tls` and is not visible here; the per-method descriptions below are
// inferred from the method names only — confirm against that module.
#[cfg(feature = "tls")]
impl<F, A, R> Server<F, accept::Tls<A>, R>
where
F: Filter + Clone + Send + Sync + 'static,
<F::Future as TryFuture>::Ok: Reply,
<F::Future as TryFuture>::Error: IsReject,
A: accept::Accept,
R: run::Run,
{
/// Forwards `path` to `TlsConfigBuilder::key_path`
/// (presumably the private key file — confirm).
pub fn key_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.key_path(path))
}
/// Forwards `path` to `TlsConfigBuilder::cert_path`
/// (presumably the certificate file — confirm).
pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.cert_path(path))
}
/// Forwards `path` to `TlsConfigBuilder::client_auth_optional_path`
/// (presumably a trust anchor file for optional client auth — confirm).
pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.client_auth_optional_path(path))
}
/// Forwards `path` to `TlsConfigBuilder::client_auth_required_path`
/// (presumably a trust anchor file for mandatory client auth — confirm).
pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.client_auth_required_path(path))
}
/// Forwards the bytes of `key` to `TlsConfigBuilder::key`.
pub fn key(self, key: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.key(key.as_ref()))
}
/// Forwards the bytes of `cert` to `TlsConfigBuilder::cert`.
pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.cert(cert.as_ref()))
}
/// Forwards the bytes of `trust_anchor` to
/// `TlsConfigBuilder::client_auth_optional`.
pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
}
/// Forwards the bytes of `trust_anchor` to
/// `TlsConfigBuilder::client_auth_required`.
pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
}
/// Forwards the bytes of `resp` to `TlsConfigBuilder::ocsp_resp`.
pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
}
/// Intended to apply `func` to the server's TLS configuration and return the
/// updated server.
///
/// NOTE(review): as written this body looks incomplete. It reads a `tls`
/// binding that is never declared in this function, and it ends without a
/// tail expression although the signature returns `Self`. `accept::Tls<A>`
/// only wraps the inner acceptor, so there is currently no place where a
/// `TlsConfigBuilder` is stored — the storage and this function need to be
/// finished together before the `tls` feature can build.
fn with_tls<Func>(self, func: Func) -> Self
where
Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
{
let tls = func(tls);
}
}
/// Connection acceptors: the [`Accept`] trait and its listener impls.
mod accept {
    use std::convert::Infallible;
    use std::future::{ready, Ready};
    use std::net::SocketAddr;

    use hyper_util::rt::TokioIo;

    /// Already-resolved "finish accepting" future used by every impl below:
    /// the IO object is available immediately and completion cannot fail.
    type ReadyConn<IO> = Ready<Result<(IO, Option<SocketAddr>), Infallible>>;

    /// A source of incoming connections.
    ///
    /// Accepting happens in two steps: `accept` waits for the next connection
    /// and returns an `Accepting` future, which then resolves to the IO object
    /// plus the peer address (when one is known).
    pub trait Accept {
        /// IO object handed to hyper for one connection.
        type IO: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static;
        /// Error produced while finishing an individual accept.
        type AcceptError: std::fmt::Debug;
        /// Second-stage future; the run loops move it into a spawned task,
        /// hence the `Send + 'static` bounds.
        type Accepting: super::Future<Output = Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>
            + Send
            + 'static;

        /// Waits for the next connection.
        // NOTE(review): the lint is presumably silenced because the run loops
        // await this future in place and only spawn `Accepting`, which has
        // explicit `Send` bounds — confirm.
        #[allow(async_fn_in_trait)]
        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error>;
    }

    /// Placeholder acceptor carried by a server that has not been bound yet.
    /// It has no [`Accept`] impl in this module.
    #[derive(Debug)]
    pub struct LazyTcp;

    /// TCP listener: yields the stream together with the peer address.
    impl Accept for tokio::net::TcpListener {
        type IO = TokioIo<tokio::net::TcpStream>;
        type AcceptError = Infallible;
        type Accepting = ReadyConn<Self::IO>;

        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            // The type-qualified path selects tokio's inherent `accept`
            // rather than recursing into this trait method.
            let (stream, peer) = <tokio::net::TcpListener>::accept(self).await?;
            let io = TokioIo::new(stream);
            Ok(ready(Ok((io, Some(peer)))))
        }
    }

    /// Unix domain socket listener: the peer address is discarded, so no
    /// remote address is reported for these connections.
    #[cfg(unix)]
    impl Accept for tokio::net::UnixListener {
        type IO = TokioIo<tokio::net::UnixStream>;
        type AcceptError = Infallible;
        type Accepting = ReadyConn<Self::IO>;

        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            let (stream, _peer) = <tokio::net::UnixListener>::accept(self).await?;
            let io = TokioIo::new(stream);
            Ok(ready(Ok((io, None))))
        }
    }

    /// Wrapper marking an acceptor as TLS-enabled.
    #[cfg(feature = "tls")]
    #[derive(Debug)]
    pub struct Tls<A>(pub(super) A);

    // NOTE(review): this impl looks like an unfinished placeholder. Per the
    // trait signature, `self.0.accept()` yields `A::Accepting` (a future), not
    // an `(io, addr)` pair, so the destructuring below does not appear to
    // type-check; no TLS handshake is performed either. Confirm the intended
    // design before enabling the `tls` feature.
    #[cfg(feature = "tls")]
    impl<A: Accept> Accept for Tls<A> {
        type IO = TokioIo<tokio::net::TcpStream>;
        type AcceptError = Infallible;
        type Accepting = ReadyConn<Self::IO>;

        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            let (io, addr) = self.0.accept().await?;
            Ok(ready(Ok((TokioIo::new(io), addr))))
        }
    }
}
/// Per-connection service wrappers.
mod middleware {
    use std::net::SocketAddr;
    use std::task::{Context, Poll};

    use tower_service::Service;

    use crate::filters::addr::RemoteAddr;

    /// Service wrapper that stamps each request with the connection's peer
    /// address (as a `RemoteAddr` extension) before delegating to `inner`.
    #[derive(Clone, Debug)]
    pub(super) struct RemoteAddrService<S> {
        inner: S,
        remote_addr: Option<SocketAddr>,
    }

    impl<S> RemoteAddrService<S> {
        /// Wraps `inner`; `remote_addr` is `None` when the transport has no
        /// socket address to report.
        pub(super) fn new(inner: S, remote_addr: Option<SocketAddr>) -> Self {
            RemoteAddrService { remote_addr, inner }
        }
    }

    impl<S, B> Service<http::Request<B>> for RemoteAddrService<S>
    where
        S: Service<http::Request<B>>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        /// Readiness is entirely determined by the wrapped service.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        /// Inserts the `RemoteAddr` extension when an address is known, then
        /// forwards the request to the wrapped service.
        fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
            if let Some(peer) = self.remote_addr.map(RemoteAddr) {
                req.extensions_mut().insert(peer);
            }
            self.inner.call(req)
        }
    }
}
mod run {
pub trait Run {
#[allow(async_fn_in_trait)]
async fn run<F, A>(server: super::Server<F, A, Self>)
where
F: super::Filter + Clone + Send + Sync + 'static,
<F::Future as super::TryFuture>::Ok: super::Reply,
<F::Future as super::TryFuture>::Error: super::IsReject,
A: super::accept::Accept,
Self: Sized;
}
#[derive(Debug)]
pub struct Standard;
impl Run for Standard {
async fn run<F, A>(mut server: super::Server<F, A, Self>)
where
F: super::Filter + Clone + Send + Sync + 'static,
<F::Future as super::TryFuture>::Ok: super::Reply,
<F::Future as super::TryFuture>::Error: super::IsReject,
A: super::accept::Accept,
Self: Sized,
{
let pipeline = server.pipeline;
loop {
let accepting = match server.acceptor.accept().await {
Ok(fut) => fut,
Err(err) => {
handle_accept_error(err).await;
continue;
}
};
let svc = crate::service(server.filter.clone());
tokio::spawn(async move {
let (io, remote_addr) = match accepting.await {
Ok(pair) => pair,
Err(err) => {
tracing::debug!("server accept error: {:?}", err);
return;
}
};
let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
let svc = hyper_util::service::TowerToHyperService::new(svc);
if let Err(err) = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.http1()
.pipeline_flush(pipeline)
.serve_connection_with_upgrades(io, svc)
.await
{
tracing::error!("server connection error: {:?}", err)
}
});
}
}
}
#[derive(Debug)]
pub struct Graceful<Fut>(pub(super) Fut);
impl<Fut> Run for Graceful<Fut>
where
Fut: super::Future<Output = ()> + Send + 'static,
{
async fn run<F, A>(mut server: super::Server<F, A, Self>)
where
F: super::Filter + Clone + Send + Sync + 'static,
<F::Future as super::TryFuture>::Ok: super::Reply,
<F::Future as super::TryFuture>::Error: super::IsReject,
A: super::accept::Accept,
Self: Sized,
{
use futures_util::future;
let pipeline = server.pipeline;
let graceful_util = hyper_util::server::graceful::GracefulShutdown::new();
let mut shutdown_signal = std::pin::pin!(server.runner.0);
loop {
let accept = std::pin::pin!(server.acceptor.accept());
let accepting = match future::select(accept, &mut shutdown_signal).await {
future::Either::Left((Ok(fut), _)) => fut,
future::Either::Left((Err(err), _)) => {
handle_accept_error(err).await;
continue;
}
future::Either::Right(((), _)) => {
tracing::debug!("shutdown signal received, starting graceful shutdown");
break;
}
};
let svc = crate::service(server.filter.clone());
let watcher = graceful_util.watcher();
tokio::spawn(async move {
let (io, remote_addr) = match accepting.await {
Ok(pair) => pair,
Err(err) => {
tracing::debug!("server accepting error: {:?}", err);
return;
}
};
let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
let svc = hyper_util::service::TowerToHyperService::new(svc);
let mut hyper = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
);
hyper.http1().pipeline_flush(pipeline);
let conn = hyper.serve_connection_with_upgrades(io, svc);
let conn = watcher.watch(conn);
if let Err(err) = conn.await {
tracing::error!("server connection error: {:?}", err)
}
});
}
drop(server.acceptor); graceful_util.shutdown().await;
}
}
async fn handle_accept_error(e: std::io::Error) {
if is_connection_error(&e) {
return;
}
tracing::error!("accept error: {:?}", e);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
fn is_connection_error(e: &std::io::Error) -> bool {
matches!(
e.kind(),
std::io::ErrorKind::ConnectionRefused
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::ConnectionReset
)
}
}