pub mod proxy_protocol;
use std::{
fmt::Display,
io,
net::SocketAddr,
os::unix::fs::PermissionsExt,
path::PathBuf,
sync::{
Arc,
atomic::{AtomicU32, AtomicU64, Ordering},
},
time::{Duration, Instant},
};
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use axum::{Router, extract::Request};
use http::Response;
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo, TokioTimer},
server::conn::auto::Builder,
};
use ic_bn_lib_common::{
traits::Run,
types::{
http::{
ALPN_ACME, Addr, ConnInfo, Error, ListenerOpts, Metrics, ProxyProtocolMode,
ServerOptions, TlsInfo,
},
tls::TlsOptions,
},
};
use prometheus::{
Registry,
core::{AtomicI64, GenericGauge},
};
use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
use rustls::sign::SingleCertAndKey;
use scopeguard::defer;
use socket2::{Domain, Socket, Type};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, UnixListener, UnixSocket},
pin, select,
sync::mpsc::channel,
time::{sleep, timeout},
};
use tokio_io_timeout::TimeoutStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tower_service::Service;
use tracing::{debug, info, warn};
use uuid::Uuid;
use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody};
use crate::tls::{pem_convert_to_rustls, prepare_server_config};
/// A 365-day span; multiples of it are used below as an "effectively never" deadline.
const YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
/// A bound listening socket: either TCP or a UNIX domain socket.
#[derive(Debug)]
pub enum Listener {
    /// Listener on a TCP socket address.
    Tcp(TcpListener),
    /// Listener on a UNIX domain socket path.
    Unix(UnixListener),
}
impl Listener {
pub fn new(addr: Addr, opts: ListenerOpts) -> Result<Self, Error> {
Ok(match addr {
Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?),
Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?),
})
}
async fn accept(&self) -> Result<(Box<dyn AsyncReadWrite>, Addr), io::Error> {
Ok(match self {
Self::Tcp(v) => {
let x = v.accept().await?;
(Box::new(x.0), Addr::Tcp(x.1))
}
Self::Unix(v) => {
let x = v.accept().await?;
(
Box::new(x.0),
Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()),
)
}
})
}
pub fn local_addr(&self) -> Option<SocketAddr> {
match &self {
Self::Tcp(v) => v.local_addr().ok(),
Self::Unix(_) => None,
}
}
}
impl From<TcpListener> for Listener {
fn from(v: TcpListener) -> Self {
Self::Tcp(v)
}
}
impl From<UnixListener> for Listener {
fn from(v: UnixListener) -> Self {
Self::Unix(v)
}
}
/// Request lifecycle events sent from the request service to the connection loop.
///
/// `Start` is sent when a request is handed to the service; `End` is handed to
/// `NotifyingBody`, which presumably sends it once the response body is finished
/// (NOTE(review): confirm against `body::NotifyingBody`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequestState {
    Start,
    End,
}
/// Runs the server side of a TLS handshake on `stream` using `rustls_cfg`.
///
/// Returns the encrypted stream together with the negotiated [`TlsInfo`], whose
/// `handshake_dur` is set to the time spent inside the handshake.
///
/// # Errors
/// Fails if the handshake fails or if [`TlsInfo`] cannot be derived from the connection.
async fn tls_handshake(
    rustls_cfg: Arc<rustls::ServerConfig>,
    stream: impl AsyncReadWrite,
) -> Result<(impl AsyncReadWrite, TlsInfo), Error> {
    let acceptor = TlsAcceptor::from(rustls_cfg);

    let started = Instant::now();
    let tls_stream = acceptor
        .accept(stream)
        .await
        .context("TLS accept failed")?;
    let handshake_dur = started.elapsed();

    let (_, server_conn) = tls_stream.get_ref();
    let mut info = TlsInfo::try_from(server_conn)?;
    info.handshake_dur = handshake_dur;

    Ok((tls_stream, info))
}
/// Everything needed to serve a single accepted connection.
struct Conn {
    /// Address the server listens on.
    addr: Addr,
    /// Peer address as reported by `accept()` (before any Proxy Protocol override).
    remote_addr: Addr,

    /// Request router shared with the server.
    router: Router,
    /// hyper auto (HTTP/1 + HTTP/2) connection builder.
    builder: Builder<TokioExecutor>,
    /// TLS configuration; `None` means plain-text connections.
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
    /// Server options (timeouts, limits, Proxy Protocol mode, ...).
    options: ServerOptions,
    /// Metric handles.
    metrics: Metrics,

    /// Cancelled to request a graceful shutdown of this connection.
    token_graceful: CancellationToken,
    /// Cancelled to stop serving this connection immediately.
    token_forceful: CancellationToken,

    /// Number of requests currently being processed on this connection.
    requests: AtomicU32,
}
impl Display for Conn {
    /// Renders as `[<listen addr>] <- [<remote addr>]`, used as a log prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            addr, remote_addr, ..
        } = self;
        write!(f, "[{addr}] <- [{remote_addr}]")
    }
}
impl Conn {
    /// Serves one accepted connection end to end and records its metrics.
    ///
    /// Processing order:
    /// 1. wrap `stream` in an [`AsyncCounter`] so bytes sent/received are tracked;
    /// 2. unless `proxy_protocol_mode` is `Off`, parse a Proxy Protocol header (in `Forced`
    ///    mode a missing header is an error) and, if one is present, take the local/remote
    ///    addresses from it;
    /// 3. if a rustls config is set, perform the TLS handshake within
    ///    `tls_handshake_timeout`; connections negotiating the ACME ALPN are shut down and
    ///    `Ok(())` is returned without serving HTTP;
    /// 4. serve HTTP via [`Conn::handle_inner`], then update the connection metrics.
    ///
    /// # Errors
    /// Returns an error when Proxy Protocol parsing, the TLS handshake or the ACME shutdown
    /// fails or times out, or when serving the connection fails. In the last case metrics
    /// are still recorded before the error is returned.
    async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
        let accepted_at = Instant::now();
        debug!("{self}: got a new connection");

        let addr = self.addr.to_string();
        // Metric label values, in order: listen address, remote address family, TLS protocol,
        // TLS cipher, "forced close" flag, "graceful shutdown" flag. The open-connections and
        // in-flight gauges and the handshake histogram only use the first four.
        let labels = &mut [
            addr.as_str(),
            self.remote_addr.family(),
            "no",
            "no",
            "no",
            "no",
        ];

        // Count bytes on the raw stream, below Proxy Protocol and TLS.
        let (stream, stats) = AsyncCounter::new(stream);

        let (stream, proxy_hdr): (Box<dyn AsyncReadWrite>, Option<ProxyHeader>) =
            if self.options.proxy_protocol_mode != ProxyProtocolMode::Off {
                let (stream, hdr) = ProxyProtocolStream::accept(stream)
                    .await
                    .context("unable to accept Proxy Protocol")?;

                if self.options.proxy_protocol_mode == ProxyProtocolMode::Forced && hdr.is_none()
                {
                    return Err(Error::NoProxyProtocolDetected);
                }

                (Box::new(stream), hdr)
            } else {
                (Box::new(stream), None)
            };

        // When a Proxy Protocol header was received, expose the addresses it carries
        // instead of the socket ones.
        let (local_addr, remote_addr) = proxy_hdr
            .map(|x| (Addr::Tcp(x.dst), Addr::Tcp(x.src)))
            .unwrap_or_else(|| (self.addr.clone(), self.remote_addr.clone()));

        let conn_info = Arc::new(ConnInfo {
            id: Uuid::now_v7(),
            accepted_at,
            remote_addr,
            local_addr,
            traffic: stats.clone(),
            req_count: AtomicU64::new(0),
            close: self.token_forceful.clone(),
        });

        let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) =
            if let Some(rustls_cfg) = &self.rustls_cfg {
                debug!("{}: performing TLS handshake", self);

                let (mut stream_tls, tls_info) = timeout(
                    self.options.tls_handshake_timeout,
                    tls_handshake(rustls_cfg.clone(), stream),
                )
                .await
                .context("TLS handshake timed out")?
                .context("TLS handshake failed")?;

                debug!(
                    "{}: handshake finished in {}ms (SNI: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})",
                    self,
                    tls_info.handshake_dur.as_millis(),
                    tls_info.sni,
                    tls_info.protocol,
                    tls_info.cipher,
                    tls_info.alpn,
                );

                // Connections negotiating the ACME ALPN (presumably TLS-ALPN-01 challenges)
                // are closed right after the handshake; no HTTP is served on them.
                if tls_info
                    .alpn
                    .as_ref()
                    .is_some_and(|x| x.as_bytes() == ALPN_ACME)
                {
                    debug!("{self}: ACME ALPN - closing connection");
                    timeout(Duration::from_secs(5), stream_tls.shutdown())
                        .await
                        .context("socket shutdown timed out")?
                        .context("socket shutdown failed")?;
                    return Ok(());
                }

                (Box::new(stream_tls), Some(Arc::new(tls_info)))
            } else {
                (Box::new(stream), None)
            };

        if let Some(v) = &tls_info {
            // `as_str()` may yield no name (e.g. for values rustls doesn't know by name); use
            // a placeholder label instead of panicking the connection task.
            labels[2] = v.protocol.as_str().unwrap_or("unknown");
            labels[3] = v.cipher.as_str().unwrap_or("unknown");

            self.metrics
                .conn_tls_handshake_duration
                .with_label_values(&labels[0..4])
                .observe(v.handshake_dur.as_secs_f64());
        }

        self.metrics
            .conns_open
            .with_label_values(&labels[0..4])
            .inc();

        let requests_inflight = self
            .metrics
            .requests_inflight
            .with_label_values(&labels[0..4]);

        let result = self
            .handle_inner(stream, conn_info.clone(), tls_info, requests_inflight)
            .await;

        let (sent, rcvd) = (stats.sent(), stats.rcvd());
        let dur = accepted_at.elapsed().as_secs_f64();
        let reqs = conn_info.req_count.load(Ordering::SeqCst);

        if self.token_forceful.is_cancelled() {
            labels[4] = "yes";
        }
        if self.token_graceful.is_cancelled() {
            labels[5] = "yes";
        }

        self.metrics.conns.with_label_values(labels).inc();
        self.metrics
            .conns_open
            .with_label_values(&labels[0..4])
            .dec();
        self.metrics.requests.with_label_values(labels).inc_by(reqs);
        self.metrics
            .bytes_rcvd
            .with_label_values(labels)
            .inc_by(rcvd);
        self.metrics
            .bytes_sent
            .with_label_values(labels)
            .inc_by(sent);
        self.metrics
            .conn_duration
            .with_label_values(labels)
            .observe(dur);
        self.metrics
            .requests_per_conn
            .with_label_values(labels)
            .observe(reqs as f64);

        debug!(
            "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
            self.token_graceful.is_cancelled(),
            self.token_forceful.is_cancelled(),
        );

        result
    }

    /// Serves HTTP on an established (optionally TLS-wrapped) stream until it ends.
    ///
    /// Applies the configured read/write timeouts to the stream and attaches `conn_info`
    /// (and `tls_info`, if any) to every request's extensions. Every handled request is
    /// counted in `conn_info.req_count`. Once `max_requests_per_conn` (if set) is reached,
    /// the graceful token is cancelled so the connection winds down.
    ///
    /// The loop ends when:
    /// - `token_forceful` is cancelled (the connection is dropped right away);
    /// - `token_graceful` is cancelled (hyper graceful shutdown, bounded by `grace_period`);
    /// - no request has been in flight for `idle_timeout`, if configured (graceful shutdown,
    ///   bounded by 5s);
    /// - hyper finishes serving the connection.
    ///
    /// # Errors
    /// Returns an error only if hyper reports a failure while serving the connection.
    async fn handle_inner(
        &self,
        stream: Box<dyn AsyncReadWrite>,
        conn_info: Arc<ConnInfo>,
        tls_info: Option<Arc<TlsInfo>>,
        requests_inflight: GenericGauge<AtomicI64>,
    ) -> Result<(), Error> {
        // Without an idle timeout the timer is parked far in the future (and its select
        // branch is disabled anyway).
        let mut idle_timer = Box::pin(sleep(self.options.idle_timeout.unwrap_or(10 * YEAR)));
        let (state_tx, mut state_rx) = channel(65536);

        let mut stream = TimeoutStream::new(stream);
        stream.set_read_timeout(self.options.read_timeout);
        stream.set_write_timeout(self.options.write_timeout);
        let stream = TokioIo::new(stream);

        let max_requests_per_conn = self.options.max_requests_per_conn;

        let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
            // Best effort: if the channel is full the idle timer just won't be paused.
            let _ = state_tx.try_send(RequestState::Start);

            request.extensions_mut().insert(conn_info.clone());
            if let Some(v) = &tls_info {
                request.extensions_mut().insert(v.clone());
            }

            let mut router = self.router.clone();
            let token = self.token_graceful.clone();
            let conn_info = conn_info.clone();
            let state_tx = state_tx.clone();
            let requests_inflight = requests_inflight.clone();

            async move {
                requests_inflight.inc();
                defer! {
                    requests_inflight.dec();
                }

                let result = router.call(request).await.map(|x| {
                    let (parts, body) = x.into_parts();
                    let body = NotifyingBody::new(body, state_tx, RequestState::End);
                    Response::from_parts(parts, body)
                });

                // Always count the request: the per-connection request metrics read this
                // counter even when no per-connection limit is configured.
                let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst) + 1;
                if let Some(max) = max_requests_per_conn
                    && req_count >= max
                {
                    token.cancel();
                }

                result
            }
        });

        let conn = self
            .builder
            .serve_connection_with_upgrades(Box::pin(stream), service);
        pin!(conn);

        loop {
            select! {
                biased;

                () = self.token_forceful.cancelled() => {
                    break;
                }

                () = self.token_graceful.cancelled() => {
                    conn.as_mut().graceful_shutdown();
                    let _ = timeout(self.options.grace_period, conn.as_mut()).await;
                    break;
                },

                Some(v) = state_rx.recv() => {
                    match v {
                        RequestState::Start => {
                            let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
                            debug!("{self}: request started");
                            if self.options.idle_timeout.is_some() {
                                debug!("{self}: stopping idle timer (now: {reqs})");
                                idle_timer.as_mut().reset(tokio::time::Instant::now() + 10 * YEAR);
                            }
                        },

                        RequestState::End => {
                            let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
                            debug!("{self}: request finished (now: {reqs})");
                            if let Some(v) = self.options.idle_timeout && reqs == 0 {
                                debug!("{self}: no outstanding requests, starting timer");
                                idle_timer.as_mut().reset(tokio::time::Instant::now() + v);
                            }
                        }
                    }
                },

                () = idle_timer.as_mut(), if self.options.idle_timeout.is_some() => {
                    debug!("{self}: Idle timeout triggered, closing");
                    conn.as_mut().graceful_shutdown();
                    let _ = timeout(Duration::from_secs(5), conn.as_mut()).await;
                    break;
                },

                v = conn.as_mut() => {
                    if let Err(e) = v {
                        return Err(anyhow!("unable to serve connection: {e:#}").into());
                    }
                    break;
                },
            }
        }

        Ok(())
    }
}
/// Builder for [`Server`]. Only the listening address is mandatory; everything else has
/// a default.
pub struct ServerBuilder {
    router: Router,
    addr: Option<Addr>,
    options: ServerOptions,
    rustls_cfg: Option<rustls::ServerConfig>,
    registry: Registry,
    metrics: Option<Metrics>,
}
impl ServerBuilder {
    /// Starts building a server that dispatches requests to `router`.
    ///
    /// Defaults: no listening address (must be set before [`Self::build`]), a fresh metrics
    /// [`Registry`], default [`ServerOptions`] and no TLS.
    #[must_use]
    pub fn new(router: Router) -> Self {
        Self {
            addr: None,
            router,
            registry: Registry::new(),
            metrics: None,
            options: ServerOptions::default(),
            rustls_cfg: None,
        }
    }

    /// Listens on the given TCP socket address.
    #[must_use]
    pub fn listen_tcp(mut self, socket: SocketAddr) -> Self {
        self.addr = Some(Addr::Tcp(socket));
        self
    }

    /// Listens on a UNIX domain socket at `path`.
    #[must_use]
    pub fn listen_unix(mut self, path: PathBuf) -> Self {
        self.addr = Some(Addr::Unix(path));
        self
    }

    /// Sets the registry used to create [`Metrics`] in [`Self::build`] when none were given
    /// via [`Self::with_metrics`].
    ///
    /// [`Self::with_rustls_single_cert`] uses whichever registry is set at the moment it is
    /// called, so call this first if both are used.
    #[must_use]
    pub fn with_metrics_registry(mut self, registry: &Registry) -> Self {
        self.registry = registry.clone();
        self
    }

    /// Uses already-created `metrics` instead of creating them from the registry.
    #[must_use]
    pub fn with_metrics(mut self, metrics: Metrics) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Enables TLS using the given rustls server configuration.
    #[must_use]
    pub fn with_rustls_config(mut self, rustls_cfg: rustls::ServerConfig) -> Self {
        self.rustls_cfg = Some(rustls_cfg);
        self
    }

    /// Replaces the server options.
    #[must_use]
    pub const fn with_options(mut self, options: ServerOptions) -> Self {
        self.options = options;
        self
    }

    /// Enables TLS with a single certificate and key loaded from the PEM files at `cert`
    /// and `key`, using default [`TlsOptions`] and the currently configured registry.
    ///
    /// # Errors
    /// Fails if either file cannot be read or the pair cannot be parsed.
    pub fn with_rustls_single_cert(mut self, cert: PathBuf, key: PathBuf) -> Result<Self, Error> {
        let cert = std::fs::read(cert).context("unable to read cert")?;
        let key = std::fs::read(key).context("unable to read key")?;
        let cert = pem_convert_to_rustls(&key, &cert).context("unable to parse cert+key pair")?;
        let resolver = SingleCertAndKey::from(cert);

        let tls_opts = TlsOptions::default();
        let rustls_cfg = prepare_server_config(tls_opts, Arc::new(resolver), &self.registry);

        self.rustls_cfg = Some(rustls_cfg);
        Ok(self)
    }

    /// Builds the [`Server`].
    ///
    /// # Errors
    /// Fails if no listening address was configured.
    pub fn build(self) -> Result<Server, Error> {
        let Some(addr) = self.addr else {
            return Err(Error::Generic(anyhow!("Listening address not specified")));
        };

        let metrics = self.metrics.unwrap_or_else(|| Metrics::new(&self.registry));

        Ok(Server::new(
            addr,
            self.router,
            self.options,
            metrics,
            self.rustls_cfg,
        ))
    }
}
/// HTTP(S) server accepting connections on a TCP or UNIX socket address.
pub struct Server {
    addr: Addr,
    router: Router,
    builder: Builder<TokioExecutor>,
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
    options: ServerOptions,
    metrics: Metrics,
    tracker: TaskTracker,
}
impl Display for Server {
    /// Renders as `[<listen addr>]`, used as a log prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let addr = &self.addr;
        write!(f, "[{addr}]")
    }
}
impl Server {
pub fn new(
addr: Addr,
router: Router,
options: ServerOptions,
metrics: Metrics,
rustls_cfg: Option<rustls::ServerConfig>,
) -> Self {
let mut builder = Builder::new(TokioExecutor::new());
builder
.http1()
.timer(TokioTimer::new()) .header_read_timeout(Some(options.http1_header_read_timeout))
.keep_alive(true)
.http2()
.adaptive_window(true)
.max_concurrent_streams(Some(options.http2_max_streams))
.timer(TokioTimer::new()) .keep_alive_interval(options.http2_keepalive_interval)
.keep_alive_timeout(options.http2_keepalive_timeout)
.enable_connect_protocol();
Self {
addr,
router,
options,
metrics,
tracker: TaskTracker::new(),
builder,
rustls_cfg: rustls_cfg.map(Arc::new),
}
}
pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
let opts = ListenerOpts {
backlog: self.options.backlog,
mss: self.options.tcp_mss,
keepalive: (&self.options).into(),
};
let listener = Listener::new(self.addr.clone(), opts)?;
self.serve_with_listener(listener, token).await
}
fn spawn_connection(
&self,
stream: Box<dyn AsyncReadWrite>,
remote_addr: Addr,
token: CancellationToken,
) {
let conn = Conn {
addr: self.addr.clone(),
remote_addr: remote_addr.clone(),
router: self.router.clone(),
builder: self.builder.clone(),
token_graceful: token,
token_forceful: CancellationToken::new(),
options: self.options,
metrics: self.metrics.clone(), requests: AtomicU32::new(0),
rustls_cfg: self.rustls_cfg.clone(),
};
self.tracker.spawn(async move {
if let Err(e) = conn.handle(stream).await {
info!(
"[{}] <- [{remote_addr}]: failed to handle connection: {e:#}",
conn.addr
);
}
debug!("[{}] <- [{remote_addr}]: connection finished", conn.addr);
});
}
pub async fn serve_with_listener(
&self,
listener: Listener,
token: CancellationToken,
) -> Result<(), Error> {
warn!("{self}: running (TLS: {})", self.rustls_cfg.is_some());
loop {
select! {
biased;
() = token.cancelled() => {
drop(listener);
warn!("{self}: shutting down, waiting for the active connections to close for {}s", self.options.grace_period.as_secs());
self.tracker.close();
select! {
() = sleep(self.options.grace_period + Duration::from_secs(5)) => {
warn!("{self}: connections didn't close in time, shutting down anyway");
},
() = self.tracker.wait() => {},
}
warn!("{self}: shut down");
if let Addr::Unix(v) = &self.addr {
let _ = std::fs::remove_file(v);
}
return Ok(());
},
v = listener.accept() => {
let (stream, remote_addr) = match v {
Ok(v) => v,
Err(e) => {
warn!("{self}: unable to accept connection: {e:#}");
sleep(Duration::from_millis(10)).await;
continue;
}
};
self.spawn_connection(stream, remote_addr, token.child_token());
}
}
}
}
}
/// Creates a non-blocking TCP listener bound to `addr`.
///
/// The socket gets `TCP_NODELAY`, `SO_REUSEADDR` and the keepalive settings from `opts`,
/// plus the TCP MSS when `opts.mss` is set.
///
/// # Errors
/// Returns an error if creating, configuring, binding, listening on or converting the
/// socket fails.
pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result<TcpListener, Error> {
    let domain = if addr.is_ipv4() {
        Domain::IPV4
    } else {
        Domain::IPV6
    };

    let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?;
    socket
        .set_tcp_nodelay(true)
        .context("unable to set TCP_NODELAY")?;

    if let Some(v) = opts.mss {
        socket.set_tcp_mss(v).context("unable to set TCP MSS")?;
    }

    socket
        .set_reuse_address(true)
        .context("unable to set SO_REUSEADDR")?;
    socket
        .set_tcp_keepalive(&opts.keepalive)
        .context("unable to set keepalive on the socket")?;
    socket
        .set_nonblocking(true)
        .context("unable to set socket into non-blocking mode")?;
    socket.bind(&addr.into()).context("unable to bind socket")?;

    // listen(2) takes a C int: saturate out-of-range values instead of letting an `as`
    // cast wrap them into a negative backlog.
    let backlog = i32::try_from(opts.backlog).unwrap_or(i32::MAX);
    socket
        .listen(backlog)
        .context("unable to listen on the socket")?;

    let listener = TcpListener::from_std(socket.into())
        .context("unable to convert socket from the standard one")?;

    Ok(listener)
}
/// Creates a UNIX domain socket listener at `path`.
///
/// Any existing file at `path` is removed first, because bind() fails on an existing
/// path. After listening, the socket file is given mode `0o666` (read/write for owner,
/// group and others).
///
/// # Errors
/// Fails if the socket cannot be created, an existing file cannot be removed, or
/// bind/listen/chmod fails.
pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result<UnixListener, Error> {
    let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?;

    if path.exists() {
        std::fs::remove_file(&path).context("unable to remove UNIX socket")?;
    }

    socket.bind(&path).context("unable to bind socket")?;
    let listener = socket
        .listen(opts.backlog)
        .context("unable to listen socket")?;

    let world_rw = std::fs::Permissions::from_mode(0o666);
    std::fs::set_permissions(&path, world_rw).context("unable to set permissions on socket")?;

    Ok(listener)
}
#[async_trait]
impl Run for Server {
async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
self.serve(token).await?;
Ok(())
}
}
#[cfg(test)]
mod test {
    use http::StatusCode;

    use super::*;

    /// Starts a plain-HTTP server with an empty router and checks that a request is
    /// answered with 404. The listener is bound before the server task is spawned, so a
    /// few short retries suffice; the test fails if no response is received at all.
    #[tokio::test]
    async fn test_server() {
        let opts = ServerOptions::default();
        let listener = listen_tcp(
            "127.0.0.1:0".parse().unwrap(),
            ListenerOpts {
                backlog: 128,
                mss: None,
                keepalive: (&opts).into(),
            },
        )
        .unwrap();
        let addr = listener.local_addr().unwrap();

        let server = Server::new(
            Addr::Tcp(addr),
            Router::new(),
            opts,
            Metrics::new(&Registry::new()),
            None,
        );

        tokio::spawn(async move {
            server
                .serve_with_listener(listener.into(), CancellationToken::new())
                .await
                .unwrap();
        });

        let mut status = None;
        for _ in 0..10 {
            match reqwest::get(format!("http://{addr}")).await {
                Ok(resp) => {
                    status = Some(resp.status());
                    break;
                }
                Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }

        assert_eq!(
            status,
            Some(StatusCode::NOT_FOUND),
            "unexpected response (None means the server never answered)"
        );
    }
}