use std::{
net::{SocketAddr, TcpListener},
num::{NonZeroU32, NonZeroUsize},
path::Path,
time::Duration,
};
use axum_server::{
tls_rustls::{RustlsAcceptor, RustlsConfig},
Handle,
};
use hyper_util::server::conn::auto::Builder;
use serde::{Deserialize, Serialize};
use socket2::SockRef;
use thiserror::Error;
use tokio::{
net::{lookup_host, TcpSocket, ToSocketAddrs},
task::JoinHandle,
};
use tracing::{debug, debug_span, error, info, Instrument};
use crate::{
errors::IoError,
signal::{SignalError, SignalStream},
};
/// Errors produced while turning a [`ServerBuilder`] configuration into a server.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ServerBuilderError {
    // ---- address resolution -------------------------------------------------
    /// `lookup_host` failed for the configured listen address.
    #[error("Unable to parse endpoint address: {0}")]
    AddressParse(IoError),
    /// The listen address resolved to no socket addresses; carries the address as written.
    #[error("Unable to resolve DNS name: {0}")]
    Resolve(String),

    // ---- socket creation and pre-bind options -------------------------------
    /// `TcpSocket::new_v4`/`new_v6` failed for every resolved address (the last error is kept).
    #[error("Unable to create socket: {0}")]
    SocketCreate(IoError),
    /// Querying the socket's domain (IPv4/IPv6) failed.
    #[error("Unable to get socket domain: {0}")]
    GetDomain(IoError),
    /// Enabling `SO_REUSEADDR` on a freshly created socket failed.
    #[error("Unable to set SO_REUSEADDR: {0}")]
    SetReuseAddr(IoError),
    /// Applying `ip.tos` failed.
    #[error("Unable to set IP_TOS/IPV6_TCLASS: {0}")]
    SetIpTos(IoError),
    /// Applying `tcp.recv_buffer` failed.
    #[error("Unable to set SO_RCVBUF: {0}")]
    SetRecvBuffer(IoError),
    /// Applying `tcp.send_buffer` failed.
    #[error("Unable to set SO_SNDBUF: {0}")]
    SetSendBuffer(IoError),
    /// Applying `tcp.mss` failed.
    #[error("Unable to set TCP_MAXSEG: {0}")]
    SetTcpMss(IoError),
    /// Enabling/configuring or disabling TCP keepalive failed.
    #[error("Unable to set SO_KEEPALIVE: {0}")]
    SetKeepAlive(IoError),
    /// Applying `tcp.nodelay` failed.
    #[error("Unable to set TCP_NODELAY: {0}")]
    SetNoDelay(IoError),

    // ---- bind / listen ------------------------------------------------------
    /// `bind` failed for the given address.
    #[error("Unable to bind socket to local address {0}: {1}")]
    BindAddr(SocketAddr, IoError),
    /// `listen` failed for the given address.
    #[error("Unable to listen on socket {0}: {1}")]
    Listen(SocketAddr, IoError),
    /// Converting the tokio listener into a `std::net::TcpListener` failed.
    #[error("Unable to perform conversion into std listener: {0}")]
    ConvertListener(IoError),
    /// Not constructed in this file. NOTE(review): confirm where it is produced.
    #[error("Unable to extract local address: {0}")]
    ListenerLocalAddr(hyper::Error),

    // ---- protocol / TLS -----------------------------------------------------
    /// Not constructed in this file. NOTE(review): confirm where it is produced.
    #[error("Neither HTTP/1 nor HTTP/2 are enabled")]
    NoProtocolsEnabled,
    /// `ServerBuilder::build_tls` was called without a `tls` section.
    #[error("No TLS configuration was provided")]
    NoTlsConfig,
    /// Loading the PEM certificate chain / private key failed.
    #[error("TLS configuration error: {0}")]
    TlsConfig(IoError),

    // ---- signals ------------------------------------------------------------
    /// Creating the signal stream failed.
    #[error(transparent)]
    SignalHandler(#[from] SignalError),
}
/// Serde-deserializable configuration for building an `axum_server` HTTP(S) server.
///
/// Every field carries a serde default, so missing keys fall back to defaults.
/// Field order is also the serialized order.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct ServerBuilder {
/// Address the plain-HTTP server (`ServerBuilder::build`) binds to; defaults to
/// `localhost:8080`. Host names are resolved when the listener is created.
#[serde(default = "ServerBuilder::default_listen")]
pub listen: String,
/// Defaults to `false`.
/// NOTE(review): not read anywhere in this file — confirm where it is consumed.
#[serde(default)]
pub sleep_on_accept_errors: bool,
/// IP-level options for the listening socket.
#[serde(default)]
pub ip: IpConfig,
/// TCP-level options for the listening socket.
#[serde(default)]
pub tcp: TcpConfig,
/// Options applied to hyper-util's HTTP/1 connection builder.
#[serde(default)]
pub http1: Http1Config,
/// Options applied to hyper-util's HTTP/2 connection builder.
#[serde(default)]
pub http2: Http2Config,
/// TLS settings; required by `ServerBuilder::build_tls`, unused by `ServerBuilder::build`.
#[serde(default)]
pub tls: Option<TlsConfig>,
}
impl Default for ServerBuilder {
fn default() -> Self {
Self {
listen: Self::default_listen(),
sleep_on_accept_errors: false,
ip: IpConfig::default(),
tcp: TcpConfig::default(),
http1: Http1Config::default(),
http2: Http2Config::default(),
tls: None,
}
}
}
impl ServerBuilder {
#[must_use]
#[inline]
fn default_listen() -> String {
"localhost:8080".into()
}
#[must_use]
#[inline]
pub fn has_tls_config(&self) -> bool {
self.tls.is_some()
}
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub async fn build(self) -> Result<axum_server::Server, ServerBuilderError> {
let span = debug_span!("build_server");
async move {
let listener = self.create_listener(&self.listen).await?;
let mut server = axum_server::from_tcp(listener);
let builder = server.http_builder();
self.configure_http1(builder);
self.configure_http2(builder);
info!("finished building plain server");
Ok(server)
}
.instrument(span)
.await
}
pub async fn build_tls(
self,
) -> Result<axum_server::Server<RustlsAcceptor>, ServerBuilderError> {
let span = debug_span!("build_tls_server");
async move {
let tls_config = self.tls.as_ref().ok_or(ServerBuilderError::NoTlsConfig)?;
let listener = self.create_listener(&tls_config.listen).await?;
let rustls_config = tls_config.rustls_config().await?;
let mut server = axum_server::from_tcp_rustls(listener, rustls_config);
let builder = server.http_builder();
self.configure_http1(builder);
self.configure_http2(builder);
info!("finished building TLS server");
Ok(server)
}
.instrument(span)
.await
}
pub async fn create_listener<O>(&self, addr_conf: O) -> Result<TcpListener, ServerBuilderError>
where
O: ToSocketAddrs + ToString,
{
let (sock, addr) = socket(addr_conf).await?;
let sref = SockRef::from(&sock);
let domain = sref
.domain()
.map_err(|err| ServerBuilderError::GetDomain(err.into()))?;
if let Some(tos) = self.ip.tos {
match domain {
socket2::Domain::IPV4 => sref.set_tos_v4(tos),
socket2::Domain::IPV6 => sref.set_tclass_v6(tos),
_ => Ok(()),
}
.map_err(|err| ServerBuilderError::SetIpTos(err.into()))?;
}
if let Some(sz) = self.tcp.recv_buffer {
sref.set_recv_buffer_size(sz.get())
.map_err(|err| ServerBuilderError::SetRecvBuffer(err.into()))?;
}
if let Some(sz) = self.tcp.send_buffer {
sref.set_send_buffer_size(sz.get())
.map_err(|err| ServerBuilderError::SetSendBuffer(err.into()))?;
}
if let Some(mss) = self.tcp.mss {
sref.set_tcp_mss(mss.get())
.map_err(|err| ServerBuilderError::SetTcpMss(err.into()))?;
}
if let Some(idle) = self.tcp.keepalive.idle {
let mut tcp_keepalive = socket2::TcpKeepalive::new().with_time(idle);
if let Some(interval) = self.tcp.keepalive.interval {
tcp_keepalive = tcp_keepalive.with_interval(interval);
}
if let Some(retries) = self.tcp.keepalive.retries {
tcp_keepalive = tcp_keepalive.with_retries(retries.get());
}
sref.set_tcp_keepalive(&tcp_keepalive)
.map_err(|err| ServerBuilderError::SetKeepAlive(err.into()))?;
} else {
sref.set_keepalive(false)
.map_err(|err| ServerBuilderError::SetKeepAlive(err.into()))?;
}
sock.bind(addr)
.map_err(|err| ServerBuilderError::BindAddr(addr, err.into()))?;
sock.set_nodelay(self.tcp.nodelay)
.map_err(|err| ServerBuilderError::SetNoDelay(err.into()))?;
sock.listen(self.tcp.backlog.get())
.map_err(|err| ServerBuilderError::Listen(addr, err.into()))?
.into_std()
.map_err(|err| ServerBuilderError::ConvertListener(err.into()))
}
pub fn configure_http1<E>(&self, builder: &mut Builder<E>) {
debug!("setting up HTTP/1");
let mut http1 = builder.http1();
http1
.half_close(self.http1.half_close)
.keep_alive(self.http1.keepalive);
if let Some(timeout) = self.http1.header_read_timeout {
http1.header_read_timeout(timeout);
}
if let Some(bufsz) = self.http1.max_buf_size {
http1.max_buf_size(bufsz.get());
}
if let Some(writev) = self.http1.writev {
http1.writev(writev);
}
}
pub fn configure_http2<E>(&self, builder: &mut Builder<E>) {
debug!("setting up HTTP/2");
let mut http2 = builder.http2();
http2
.adaptive_window(self.http2.adaptive_window)
.initial_connection_window_size(
self.http2.initial_connection_window.map(NonZeroU32::get),
)
.initial_stream_window_size(self.http2.initial_stream_window.map(NonZeroU32::get))
.keep_alive_interval(self.http2.keepalive.interval)
.max_concurrent_streams(self.http2.max_concurrent_streams.map(NonZeroU32::get));
if self.http2.connect_protocol {
http2.enable_connect_protocol();
}
if let Some(timeout) = self.http2.keepalive.timeout {
http2.keep_alive_timeout(timeout);
}
}
pub fn spawn_signal_handler(
&self,
handle: Handle,
) -> Result<JoinHandle<()>, ServerBuilderError> {
let span = debug_span!("signal_handler");
let mut sig = SignalStream::new()?;
Ok(tokio::spawn(
async move {
loop {
match sig.next().await {
Ok(sig) if sig.is_shutdown() => {
info!("received {}, shutting down server", sig.name());
handle.graceful_shutdown(Some(Duration::from_secs(5)));
break;
}
Ok(sig) => {
debug!("don't know what to do with signal {}, ignoring", sig.name());
}
Err(err) => {
error!("error in signal handler: {err}");
}
}
}
}
.instrument(span),
))
}
}
/// IP-level options applied to the listening socket.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct IpConfig {
/// Value applied as `IP_TOS` on IPv4 sockets or `IPV6_TCLASS` on IPv6 sockets.
/// `None` leaves the option untouched.
#[serde(default)]
pub tos: Option<u32>,
}
/// TCP-level options applied to the listening socket.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct TcpConfig {
/// Sets `TCP_NODELAY` on the listening socket (after `bind`, before `listen`).
/// `Default` sets it to `true`; the serde default is `crate::util::default_true`.
#[serde(default = "crate::util::default_true")]
pub nodelay: bool,
/// Requested `SO_RCVBUF` size; `None` leaves it untouched.
#[serde(default)]
pub recv_buffer: Option<NonZeroUsize>,
/// Requested `SO_SNDBUF` size; `None` leaves it untouched.
#[serde(default)]
pub send_buffer: Option<NonZeroUsize>,
/// Backlog passed to `listen`; defaults to 1024.
#[serde(default = "TcpConfig::default_backlog")]
pub backlog: NonZeroU32,
/// `TCP_MAXSEG` value; `None` leaves it untouched.
#[serde(default)]
pub mss: Option<NonZeroU32>,
/// Keepalive probe settings. Keepalive is disabled unless `keepalive.idle` is set.
#[serde(default)]
pub keepalive: TcpKeepaliveConfig,
}
impl Default for TcpConfig {
fn default() -> Self {
Self {
nodelay: true,
recv_buffer: None,
send_buffer: None,
backlog: Self::default_backlog(),
mss: None,
keepalive: TcpKeepaliveConfig::default(),
}
}
}
impl TcpConfig {
    /// Listen backlog used when `tcp.backlog` is absent from the configuration.
    #[inline]
    #[must_use]
    // The argument is a non-zero constant, so this `unwrap` can never fail.
    #[allow(clippy::unwrap_used)]
    fn default_backlog() -> NonZeroU32 {
        const DEFAULT_BACKLOG: u32 = 1024;
        NonZeroU32::new(DEFAULT_BACKLOG).unwrap()
    }
}
/// TCP keepalive probe settings for the listening socket.
///
/// Keepalive is enabled only when `idle` is set. `interval` and `retries` refine the probe
/// schedule but have no effect on their own. Unset fields are omitted when serializing.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct TcpKeepaliveConfig {
    /// Passed to `socket2::TcpKeepalive::with_time`. Setting it enables keepalive.
    /// Accepts human-readable durations (`humantime_serde`).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "humantime_serde"
    )]
    pub idle: Option<Duration>,
    /// Passed to `socket2::TcpKeepalive::with_interval` when `idle` is also set.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "humantime_serde"
    )]
    pub interval: Option<Duration>,
    /// Passed to `socket2::TcpKeepalive::with_retries` when `idle` is also set.
    // Same serde treatment as the sibling fields: optional on input, omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retries: Option<NonZeroU32>,
}
/// HTTP/1 options applied to hyper-util's connection builder.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Http1Config {
/// Forwarded to the builder's `half_close`; defaults to `false`.
#[serde(default)]
pub half_close: bool,
/// Forwarded to `header_read_timeout` only when set.
/// Accepts human-readable durations (`humantime_serde`).
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "humantime_serde"
)]
pub header_read_timeout: Option<Duration>,
/// Forwarded to `keep_alive`; `Default` sets it to `true`.
#[serde(default = "crate::util::default_true")]
pub keepalive: bool,
/// Forwarded to `max_buf_size` only when set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_buf_size: Option<NonZeroUsize>,
/// Forwarded to `writev` only when set; `None` keeps the builder's own default.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub writev: Option<bool>,
}
impl Default for Http1Config {
    /// Keep-alive on, half-close off, and every optional tuning knob left unset.
    fn default() -> Self {
        Self {
            keepalive: true,
            half_close: false,
            header_read_timeout: Option::None,
            max_buf_size: Option::None,
            writev: Option::None,
        }
    }
}
/// HTTP/2 options applied to hyper-util's connection builder.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Http2Config {
/// Forwarded to `adaptive_window`; defaults to `false`.
#[serde(default)]
pub adaptive_window: bool,
/// When `true`, `enable_connect_protocol` is called; defaults to `false`.
#[serde(default)]
pub connect_protocol: bool,
/// Forwarded to `initial_connection_window_size`; `None` is forwarded as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub initial_connection_window: Option<NonZeroU32>,
/// Forwarded to `initial_stream_window_size`; `None` is forwarded as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub initial_stream_window: Option<NonZeroU32>,
/// HTTP/2 keepalive interval and timeout.
#[serde(default)]
pub keepalive: Http2KeepaliveConfig,
/// Forwarded to `max_concurrent_streams`; `None` is forwarded as `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_concurrent_streams: Option<NonZeroU32>,
}
/// HTTP/2 keepalive settings; durations use the `humantime_serde` format.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct Http2KeepaliveConfig {
/// Always forwarded to `keep_alive_interval`, including `None`.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "humantime_serde"
)]
pub interval: Option<Duration>,
/// Forwarded to `keep_alive_timeout` only when set.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "humantime_serde"
)]
pub timeout: Option<Duration>,
}
/// TLS listener settings consumed by `ServerBuilder::build_tls`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct TlsConfig {
/// Address the TLS server binds to; defaults to `localhost:8443`.
#[serde(default = "TlsConfig::default_listen")]
pub listen: String,
/// Path to the PEM certificate chain.
/// Accepted config keys: `certificate`, `cert`, `chain`.
#[serde(alias = "cert", alias = "chain")]
certificate: Box<Path>,
/// Path to the PEM private key. Accepted config keys: `private_key`, `key`.
#[serde(alias = "key")]
private_key: Box<Path>,
}
impl TlsConfig {
    /// Listen address used when the `tls` section omits `listen`.
    #[must_use]
    #[inline]
    fn default_listen() -> String {
        String::from("localhost:8443")
    }

    /// Loads the configured certificate chain and private key into a [`RustlsConfig`].
    ///
    /// # Errors
    ///
    /// [`ServerBuilderError::TlsConfig`] when `RustlsConfig::from_pem_chain_file` fails,
    /// presumably because a file is unreadable or malformed.
    pub async fn rustls_config(&self) -> Result<RustlsConfig, ServerBuilderError> {
        let loaded =
            RustlsConfig::from_pem_chain_file(&self.certificate, &self.private_key).await;
        loaded.map_err(|io_err| ServerBuilderError::TlsConfig(io_err.into()))
    }
}
/// Resolves `origin` and creates a socket for the first address that accepts one.
///
/// # Errors
///
/// - The resolution error, if name lookup fails.
/// - The error from the *last* failed socket creation, if none succeeded.
/// - [`ServerBuilderError::Resolve`] with `origin`'s text, if resolution yielded no
///   addresses at all.
async fn socket<O>(origin: O) -> Result<(TcpSocket, SocketAddr), ServerBuilderError>
where
    O: ToSocketAddrs + ToString,
{
    let mut last_failure = None;
    for candidate in resolve(&origin).await? {
        match sock_create(&candidate) {
            Ok(created) => return Ok((created, candidate)),
            Err(failure) => last_failure = Some(failure),
        }
    }
    Err(last_failure.unwrap_or_else(|| ServerBuilderError::Resolve(origin.to_string())))
}
/// Performs name lookup for `origin` via tokio's `lookup_host`.
///
/// # Errors
///
/// [`ServerBuilderError::AddressParse`] wrapping the lookup error.
async fn resolve<O>(origin: &O) -> Result<impl Iterator<Item = SocketAddr> + '_, ServerBuilderError>
where
    O: ToSocketAddrs + ToString,
{
    let addrs = lookup_host(origin)
        .await
        .map_err(|io_err| ServerBuilderError::AddressParse(io_err.into()))?;
    Ok(addrs)
}
/// Creates an unbound TCP socket whose family (IPv4/IPv6) matches `addr`.
///
/// On non-Windows targets `SO_REUSEADDR` is enabled as well.
///
/// # Errors
///
/// [`ServerBuilderError::SocketCreate`] or [`ServerBuilderError::SetReuseAddr`].
fn sock_create(addr: &SocketAddr) -> Result<TcpSocket, ServerBuilderError> {
    let created = if addr.is_ipv4() {
        TcpSocket::new_v4()
    } else {
        TcpSocket::new_v6()
    };
    let socket = created.map_err(|io_err| ServerBuilderError::SocketCreate(io_err.into()))?;
    // Skipped on Windows, presumably because SO_REUSEADDR semantics differ there — confirm.
    #[cfg(not(windows))]
    socket
        .set_reuseaddr(true)
        .map_err(|io_err| ServerBuilderError::SetReuseAddr(io_err.into()))?;
    Ok(socket)
}