#![cfg_attr(
not(any(feature = "rustls-tls", feature = "openssl-tls")),
allow(dead_code, unused_variables)
)]
use std::{convert::Infallible, net::SocketAddr, path::PathBuf, str::FromStr, sync::Arc};
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tower::Service;
use tracing::{debug, error, info, info_span, Instrument};
#[cfg(feature = "rustls-tls")]
mod tls_rustls;
#[cfg(feature = "openssl-tls")]
mod tls_openssl;
#[cfg(test)]
mod tests;
// Command-line arguments used to configure the server.
//
// Call `ServerArgs::bind` to validate the TLS settings and bind the listener.
//
// NOTE(review): plain `//` comments are used in this struct on purpose. With
// the `clap` feature enabled, `///` doc comments on a `clap::Parser` struct and
// its fields are picked up as help text, which would change the CLI output.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct ServerArgs {
// Socket address the server listens on. With `clap`, this is the
// `--server-addr` flag and defaults to `0.0.0.0:443`.
#[cfg_attr(feature = "clap", clap(long, default_value = "0.0.0.0:443"))]
pub server_addr: SocketAddr,
// Path to the server's TLS key. Optional at parse time, but `bind` fails with
// `Error::NoTlsKey` when it is absent.
#[cfg_attr(feature = "clap", clap(long))]
pub server_tls_key: Option<TlsKeyPath>,
// Path to the server's TLS certificates. Optional at parse time, but `bind`
// fails with `Error::NoTlsCerts` when it is absent.
#[cfg_attr(feature = "clap", clap(long))]
pub server_tls_certs: Option<TlsCertPath>,
}
/// A server whose TCP listener has been bound but which is not yet accepting
/// connections.
///
/// Produced by [`ServerArgs::bind`]; call [`Bound::spawn`] to start serving.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct Bound {
/// Local address reported by the bound listener.
local_addr: SocketAddr,
/// The bound TCP listener that connections will be accepted from.
tcp: tokio::net::TcpListener,
/// Paths of the TLS key and certificates, shared with every connection task.
tls: Arc<TlsPaths>,
}
/// Handle to a server whose accept loop is running on a spawned Tokio task.
///
/// Produced by [`Bound::spawn`].
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct SpawnedServer {
/// Local address the server's listener is bound to.
local_addr: SocketAddr,
/// Join handle of the task running the accept loop.
task: tokio::task::JoinHandle<()>,
}
/// Errors that can occur while configuring and binding the server.
#[derive(Debug, Error)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
#[non_exhaustive]
pub enum Error {
/// No TLS key path was configured.
#[error("--server-tls-key must be set")]
NoTlsKey,
/// No TLS certificate path was configured.
#[error("--server-tls-certs must be set")]
NoTlsCerts,
/// The TLS certificates could not be read.
// NOTE(review): not constructed in this file; presumably produced by the TLS
// backend modules.
#[error("failed to read TLS certificates: {0}")]
TlsCertsReadError(#[source] std::io::Error),
/// The TLS key could not be read.
// NOTE(review): not constructed in this file; presumably produced by the TLS
// backend modules.
#[error("failed to read TLS key: {0}")]
TlsKeyReadError(#[source] std::io::Error),
/// The TLS key and certificates were read but could not be loaded as
/// credentials.
// NOTE(review): not constructed in this file; presumably produced by the TLS
// backend modules.
#[error("failed to load TLS credentials: {0}")]
InvalidTlsCredentials(#[source] Box<dyn std::error::Error + Send + Sync>),
/// Binding the TCP listener to the given address failed.
#[error("failed to bind {0:?}: {1}")]
Bind(SocketAddr, #[source] std::io::Error),
/// The bound listener's local address could not be determined.
#[error("failed to get bound local address: {0}")]
LocalAddr(#[source] std::io::Error),
}
/// Filesystem path of the server's TLS key.
///
/// A newtype over [`PathBuf`] so the key path cannot be confused with the
/// certificate path. Constructed by parsing a string (see the `FromStr` impl).
#[derive(Clone, Debug)]
pub struct TlsKeyPath(PathBuf);
/// Filesystem path of the server's TLS certificates.
///
/// A newtype over [`PathBuf`] so the certificate path cannot be confused with
/// the key path. Constructed by parsing a string (see the `FromStr` impl).
#[derive(Clone, Debug)]
pub struct TlsCertPath(PathBuf);
/// The pair of TLS credential paths a server was configured with.
///
/// Shared behind an `Arc` between the accept loop and every connection task.
#[derive(Clone, Debug)]
struct TlsPaths {
/// Path of the TLS key.
key: TlsKeyPath,
/// Path of the TLS certificates.
certs: TlsCertPath,
}
impl ServerArgs {
pub async fn bind(self) -> Result<Bound, Error> {
let tls = {
let key = self.server_tls_key.ok_or(Error::NoTlsKey)?;
let certs = self.server_tls_certs.ok_or(Error::NoTlsCerts)?;
#[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
let _ = tls_openssl::load_tls(&key, &certs).await?;
#[cfg(feature = "rustls-tls")]
let _ = tls_rustls::load_tls(&key, &certs).await?;
Arc::new(TlsPaths { key, certs })
};
let tcp = TcpListener::bind(&self.server_addr)
.await
.map_err(|e| Error::Bind(self.server_addr, e))?;
let local_addr = tcp.local_addr().map_err(Error::LocalAddr)?;
Ok(Bound {
local_addr,
tcp,
tls,
})
}
}
impl Bound {
    /// Returns the local address reported by the bound listener.
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = self;
        *local_addr
    }

    /// Starts accepting connections on a background Tokio task.
    ///
    /// Each accepted connection is served with a clone of `service`. The
    /// accept loop stops once `drain` is signaled. The returned
    /// [`SpawnedServer`] can be used to abort or join the accept task.
    pub fn spawn<S, B>(self, service: S, drain: drain::Watch) -> SpawnedServer
    where
        S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
            + Clone
            + Send
            + 'static,
        S::Error: std::error::Error + Send + Sync,
        S::Future: Send,
        B: hyper::body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync,
    {
        let local_addr = self.local_addr;

        // Everything logged by the accept loop is tagged with the bound port.
        let accepting = accept_loop(self.tcp, drain, service, self.tls);
        let span = info_span!("server", port = %local_addr.port());
        let task = tokio::spawn(accepting.instrument(span));

        SpawnedServer { local_addr, task }
    }
}
impl SpawnedServer {
    /// Returns the local address the server's listener is bound to.
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = self;
        *local_addr
    }

    /// Aborts the background task running the accept loop.
    ///
    /// Connection tasks are spawned separately by the accept loop, so this
    /// only targets the accept task itself.
    pub fn abort(&self) {
        let Self { task, .. } = self;
        task.abort();
    }

    /// Waits for the accept task to finish.
    ///
    /// # Errors
    ///
    /// Returns the task's [`tokio::task::JoinError`] if it did not complete
    /// normally.
    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
        let Self { task, .. } = self;
        task.await
    }
}
async fn accept_loop<S, B>(tcp: TcpListener, drain: drain::Watch, service: S, tls: Arc<TlsPaths>)
where
S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
+ Clone
+ Send
+ 'static,
S::Error: std::error::Error + Send + Sync,
S::Future: Send,
B: hyper::body::Body + Send + 'static,
B::Data: Send,
B::Error: std::error::Error + Send + Sync,
{
tracing::debug!("listening");
loop {
tracing::trace!("accepting");
let socket = tokio::select! {
biased;
release = drain.clone().signaled() => {
drop(release);
return;
}
res = tcp.accept() => match res {
Ok((socket, _)) => socket,
Err(error) => {
error!(%error, "Failed to accept connection");
continue;
}
},
};
if let Err(error) = socket.set_nodelay(true) {
error!(%error, "Failed to set TCP_NODELAY");
continue;
}
let client_addr = match socket.peer_addr() {
Ok(addr) => addr,
Err(error) => {
error!(%error, "Failed to get peer address");
continue;
}
};
tokio::spawn(
serve_conn(socket, drain.clone(), service.clone(), tls.clone()).instrument(info_span!(
"conn",
client.ip = %client_addr.ip(),
client.port = %client_addr.port(),
)),
);
}
}
/// Serves one accepted TCP connection: TLS handshake first, then HTTP.
///
/// The TLS credentials are loaded from the configured paths for this
/// connection before the handshake. If loading them or the handshake fails,
/// the failure is logged and the connection is dropped.
///
/// The HTTP connection is driven on its own spawned task. This function then
/// waits for whichever happens first:
///
/// * the drain is signaled, in which case the connection is asked to shut down
///   gracefully and the drain latch is released once that has completed; or
/// * the connection task finishes, in which case this function returns right
///   away instead of staying parked (and holding its `drain::Watch`) until the
///   process shuts down.
async fn serve_conn<S, B>(socket: TcpStream, drain: drain::Watch, service: S, tls: Arc<TlsPaths>)
where
    S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
        + Clone
        + Send
        + 'static,
    S::Error: std::error::Error + Send + Sync,
    S::Future: Send,
    B: hyper::body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync,
{
    tracing::debug!("accepted TCP connection");

    let socket = {
        let TlsPaths { key, certs } = &*tls;

        // Credentials are loaded per connection rather than cached.
        // NOTE(review): presumably so that rotated credentials are picked up
        // without a restart -- confirm.
        // rustls takes precedence when both backends are enabled.
        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
        let res = tls_openssl::load_tls(key, certs).await;
        #[cfg(feature = "rustls-tls")]
        let res = tls_rustls::load_tls(key, certs).await;
        // Without a TLS backend every connection is rejected. `Accept` is an
        // uninhabited stand-in for the acceptor type so the code below still
        // type-checks.
        #[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
        let res = {
            enum Accept {}
            Err::<Accept, _>(std::io::Error::other("TLS support not enabled"))
        };
        let tls = match res {
            Ok(tls) => tls,
            Err(error) => {
                info!(%error, "Connection failed");
                return;
            }
        };
        tracing::trace!("loaded TLS credentials");

        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
        let res = tls_openssl::accept(&tls, socket).await;
        #[cfg(feature = "rustls-tls")]
        let res = tls_rustls::accept(&tls, socket).await;
        #[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
        let res = Err::<TcpStream, _>(std::io::Error::other("TLS support not enabled"));
        let socket = match res {
            Ok(s) => s,
            Err(error) => {
                info!(%error, "TLS handshake failed");
                return;
            }
        };
        tracing::trace!("TLS handshake completed");

        socket
    };

    /// Executor handed to hyper so that the tasks it spawns run on Tokio and
    /// inherit the current tracing span.
    #[derive(Copy, Clone, Debug)]
    struct Executor;

    impl<F> hyper::rt::Executor<F> for Executor
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send + 'static,
    {
        fn execute(&self, fut: F) {
            tokio::spawn(fut.in_current_span());
        }
    }

    // With a compression feature enabled, wrap the service in the tower-http
    // decompression and compression layers.
    #[cfg(any(feature = "server-brotli", feature = "server-gzip"))]
    let service = tower_http::decompression::Decompression::new(
        tower_http::compression::Compression::new(service),
    );

    let mut builder = hyper_util::server::conn::auto::Builder::new(Executor);
    // Limit how long an HTTP/1 client may take to send its request headers, so
    // idle clients cannot hold connections open indefinitely.
    builder
        .http1()
        .header_read_timeout(std::time::Duration::from_secs(2))
        .timer(hyper_util::rt::TokioTimer::default());

    let graceful = hyper_util::server::graceful::GracefulShutdown::new();
    let conn = graceful.watch(
        builder
            .serve_connection(
                hyper_util::rt::TokioIo::new(socket),
                hyper_util::service::TowerToHyperService::new(service),
            )
            .into_owned(),
    );

    // Drive the connection on its own task so it keeps making progress while
    // this function waits on the drain signal.
    let conn_task = tokio::spawn(
        async move {
            match conn.await {
                Ok(()) => debug!("Connection closed"),
                Err(error) => info!(%error, "Connection lost"),
            }
        }
        .in_current_span(),
    );

    // Finish as soon as either the drain fires or the connection ends.
    // Awaiting only the drain signal would leave one parked task per closed
    // connection alive until shutdown.
    tokio::select! {
        biased;
        latch = drain.signaled() => {
            // Dropping the join handle here only detaches the connection
            // task, which keeps running until the graceful shutdown is done.
            latch.release_after(graceful.shutdown()).await;
        }
        // The connection has already logged its outcome; a join error is not
        // actionable here.
        _ = conn_task => {}
    }
}
impl FromStr for TlsCertPath {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse().map(Self)
}
}
impl FromStr for TlsKeyPath {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse().map(Self)
}
}