use super::settings::TelemetrySettings;
use crate::BootstrapResult;
use crate::addr::ListenAddr;
use crate::telemetry::log;
use anyhow::Context as _;
use futures_util::future::FutureExt;
use futures_util::{pin_mut, ready};
use hyper_util::rt::TokioIo;
use socket2::{Domain, SockAddr, Socket, Type};
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::watch;
mod router;
#[cfg(feature = "memory-profiling")]
mod pprof_symbol;
use router::Router;
/// A connection accepted by the telemetry server.
///
/// It wraps either a TCP stream or, on Unix only, a Unix domain socket stream. The
/// `AsyncRead`/`AsyncWrite` impls below forward every call to the wrapped stream.
enum TelemetryStream {
    Tcp(TcpStream),
    #[cfg(unix)]
    Unix(UnixStream),
}
/// Reads are forwarded to whichever transport the connection was accepted on.
impl AsyncRead for TelemetryStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Every variant wraps an `Unpin` stream, so the inner stream can simply be re-pinned.
        match &mut *self {
            Self::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(unix)]
            Self::Unix(unix) => Pin::new(unix).poll_read(cx, buf),
        }
    }
}
/// Writes are forwarded to whichever transport the connection was accepted on.
///
/// The vectored-write methods are forwarded as well. Without them the trait defaults would
/// report the stream as non-vectored and write only the first non-empty buffer per call. That
/// would make hyper flatten its output into an extra copy, even though both wrapped tokio
/// streams support vectored I/O natively.
impl AsyncWrite for TelemetryStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            TelemetryStream::Tcp(stream) => stream.is_write_vectored(),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}
/// The bound listening socket of the telemetry server.
///
/// It is either a TCP listener or, on Unix only, a Unix domain socket listener. Each accepted
/// connection is yielded as the matching `TelemetryStream` variant.
enum TelemetryListener {
    Tcp(TcpListener),
    #[cfg(unix)]
    Unix(UnixListener),
}
impl TelemetryListener {
    /// Returns the address this listener is bound to.
    ///
    /// # Errors
    ///
    /// Fails if the OS cannot report the local address. On Unix it also fails when the socket
    /// address has no filesystem path.
    pub(crate) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
        let addr: ListenAddr = match self {
            Self::Tcp(listener) => listener.local_addr()?.into(),
            #[cfg(unix)]
            Self::Unix(listener) => {
                let unix_addr = listener.local_addr()?;
                let Some(path) = unix_addr.as_pathname() else {
                    return Err(anyhow::anyhow!("unix socket listener has no pathname"));
                };
                path.to_path_buf().into()
            }
        };
        Ok(addr)
    }

    /// Waits for the next incoming connection. The peer address is discarded.
    pub(crate) async fn accept(&self) -> std::io::Result<TelemetryStream> {
        let stream = match self {
            Self::Tcp(listener) => {
                let (conn, _peer) = listener.accept().await?;
                TelemetryStream::Tcp(conn)
            }
            #[cfg(unix)]
            Self::Unix(listener) => {
                let (conn, _peer) = listener.accept().await?;
                TelemetryStream::Unix(conn)
            }
        };
        Ok(stream)
    }

    /// Poll-based counterpart of `accept`.
    ///
    /// Returns `Pending` until a connection (or an accept error) is available. The peer
    /// address is discarded.
    pub(crate) fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<TelemetryStream>> {
        match self {
            Self::Tcp(listener) => listener
                .poll_accept(cx)
                .map_ok(|(conn, _peer)| TelemetryStream::Tcp(conn)),
            #[cfg(unix)]
            Self::Unix(listener) => listener
                .poll_accept(cx)
                .map_ok(|(conn, _peer)| TelemetryStream::Unix(conn)),
        }
    }
}
pub use router::{
TelemetryRouteBody, TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute,
};
/// The telemetry HTTP server: a bound listener plus the router that serves each connection.
///
/// Awaiting it directly (via its `Future` impl) serves connections forever. Use
/// `with_graceful_shutdown` to serve until a shutdown signal resolves.
pub(super) struct TelemetryServerFuture {
    /// Socket on which incoming connections are accepted.
    listener: TelemetryListener,
    /// Request router, cloned for every accepted connection.
    router: Router,
}
impl TelemetryServerFuture {
    /// Builds the telemetry server and binds its listener up front.
    ///
    /// Returns `Ok(None)`, without binding anything, when `settings.server.enabled` is false.
    /// Otherwise the listener is bound to `settings.server.addr`:
    /// - TCP addresses are bound with `bind_socket` and converted to a non-blocking Tokio
    ///   listener.
    /// - On Unix, any existing file at a Unix socket path is removed before binding. A failed
    ///   removal is only logged.
    ///
    /// # Errors
    ///
    /// Fails if binding or configuring the listener fails. On Linux with the `memory-profiling`
    /// feature, it also fails if the memory profiler is enabled but cannot be obtained.
    pub(super) fn new(
        settings: TelemetrySettings,
        custom_routes: Vec<TelemetryServerRoute>,
    ) -> BootstrapResult<Option<TelemetryServerFuture>> {
        if !settings.server.enabled {
            return Ok(None);
        }
        let settings = Arc::new(settings);
        // Presumably resolved eagerly so a profiler misconfiguration surfaces at startup
        // rather than on the first profiling request.
        #[cfg(all(target_os = "linux", feature = "memory-profiling"))]
        if settings.memory_profiler.enabled {
            memory_profiling::profiler(Arc::clone(&settings))
                .map_err(|err| anyhow::anyhow!(err))?;
        }
        let router = Router::new(custom_routes, Arc::clone(&settings));
        let listener = match &settings.server.addr {
            ListenAddr::Tcp(addr) => {
                let std_listener = std::net::TcpListener::from(
                    bind_socket(*addr)
                        .with_context(|| format!("binding to TCP socket {addr:?}"))?,
                );
                // `TcpListener::from_std` expects the std listener to already be non-blocking.
                std_listener.set_nonblocking(true)?;
                let tokio_listener = tokio::net::TcpListener::from_std(std_listener)?;
                TelemetryListener::Tcp(tokio_listener)
            }
            #[cfg(unix)]
            ListenAddr::Unix(path) => {
                // Binding a Unix socket fails if the path already exists, so clear whatever is
                // there (presumably a stale socket from a previous run).
                if path.exists()
                    && let Err(e) = std::fs::remove_file(path)
                {
                    log::warn!("failed to remove existing Unix socket file"; "path" => %path.display(), "error" => e);
                }
                let unix_listener = UnixListener::bind(path)
                    .with_context(|| format!("binding to Unix socket {path:?}"))?;
                TelemetryListener::Unix(unix_listener)
            }
        };
        Ok(Some(TelemetryServerFuture { listener, router }))
    }

    /// Returns the address the server's listener is bound to.
    pub(super) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
        self.listener.local_addr()
    }

    /// Serves connections until `shutdown_signal` resolves, then drains in-flight connections.
    ///
    /// Once the signal fires:
    /// - the accept loop stops and the listening socket is closed;
    /// - every live connection is asked to shut down gracefully;
    /// - the returned future completes after all spawned connection tasks have finished.
    ///
    /// Accept errors are logged and accepting continues.
    pub(super) async fn with_graceful_shutdown(
        self,
        shutdown_signal: impl Future<Output = ()> + Send + Sync + 'static,
    ) {
        // Shutdown broadcast: the only receiver is dropped once `shutdown_signal` resolves.
        // That makes `closed()` resolve for every holder of the shared sender.
        let (signal_tx, signal_rx) = watch::channel(());
        let signal_tx = Arc::new(signal_tx);
        tokio::spawn(async move {
            shutdown_signal.await;
            drop(signal_rx);
        });
        // Drain tracking: each connection task owns a clone of `close_rx`.
        // `close_tx.closed()` resolves once every clone (and the original) has been dropped.
        let (close_tx, close_rx) = watch::channel(());
        let listener = self.listener;
        loop {
            let socket = tokio::select! {
                conn = listener.accept() => match conn {
                    Ok(conn) => TokioIo::new(conn),
                    Err(e) => {
                        log::warn!("failed to accept connection"; "error" => e);
                        continue;
                    }
                },
                _ = signal_tx.closed() => { break },
            };
            let router = self.router.clone();
            let signal_tx = Arc::clone(&signal_tx);
            let close_rx = close_rx.clone();
            tokio::spawn(async move {
                let conn = hyper::server::conn::http1::Builder::new()
                    .serve_connection(socket, router)
                    .with_upgrades();
                // Fused so that, after the signal has fired once, it stays `Pending`.
                // `graceful_shutdown` is then not requested again on later iterations.
                let signal_closed = signal_tx.closed().fuse();
                pin_mut!(conn);
                pin_mut!(signal_closed);
                loop {
                    tokio::select! {
                        _ = conn.as_mut() => break,
                        _ = &mut signal_closed => conn.as_mut().graceful_shutdown(),
                    }
                }
                // Lets the drain wait below know this connection is finished.
                drop(close_rx);
            });
        }
        // Close the listening socket right away. New clients are then refused instead of
        // waiting in the accept backlog, never to be served, while existing connections drain.
        drop(listener);
        drop(close_rx);
        close_tx.closed().await;
    }
}
impl Future for TelemetryServerFuture {
    type Output = Infallible;

    /// Accepts connections forever, spawning one HTTP/1 connection task (with upgrade support)
    /// per connection.
    ///
    /// There is no shutdown path, hence the `Infallible` output. Accept errors are logged and
    /// accepting resumes immediately.
    ///
    /// NOTE(review): a persistent accept error (e.g. file-descriptor exhaustion) is retried with
    /// no backoff; consider throttling.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let server = self.get_mut();
        loop {
            match ready!(server.listener.poll_accept(cx)) {
                Err(e) => {
                    log::warn!("failed to accept connection"; "error" => e);
                }
                Ok(stream) => {
                    let io = TokioIo::new(stream);
                    let router = server.router.clone();
                    let connection = hyper::server::conn::http1::Builder::new()
                        .serve_connection(io, router)
                        .with_upgrades();
                    tokio::spawn(connection);
                }
            }
        }
    }
}
/// Creates a TCP socket bound to `addr` and puts it into listening mode.
///
/// Before binding it enables `SO_REUSEADDR`, plus `SO_REUSEPORT` on Unix. It then listens
/// with a backlog of 1024 pending connections. The socket is left in blocking mode; the
/// caller is responsible for any further configuration.
fn bind_socket(addr: SocketAddr) -> BootstrapResult<Socket> {
    let domain = match addr {
        SocketAddr::V4(_) => Domain::IPV4,
        SocketAddr::V6(_) => Domain::IPV6,
    };
    let socket = Socket::new(domain, Type::STREAM, None)?;

    socket.set_reuse_address(true)?;
    #[cfg(unix)]
    socket.set_reuse_port(true)?;

    let sock_addr = SockAddr::from(addr);
    socket.bind(&sock_addr)?;

    let backlog = 1024;
    socket.listen(backlog)?;

    Ok(socket)
}
#[cfg(all(target_os = "linux", feature = "memory-profiling"))]
mod memory_profiling {
    //! Accessors for the memory profiler used by the telemetry server.

    use super::*;
    use crate::Result;
    use crate::telemetry::MemoryProfiler;

    /// Returns the memory profiler configured by `settings.memory_profiler`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `MemoryProfiler::get_or_init_with`. If no profiler is
    /// available, fails with a hint to enable profiling via `_RJEM_MALLOC_CONF`.
    pub(super) fn profiler(settings: Arc<TelemetrySettings>) -> Result<MemoryProfiler> {
        let maybe_profiler = MemoryProfiler::get_or_init_with(&settings.memory_profiler)?;
        match maybe_profiler {
            Some(found) => Ok(found),
            None => Err(
                "profiling should be enabled via `_RJEM_MALLOC_CONF=prof:true` env var".into(),
            ),
        }
    }

    /// Obtains the profiler and returns its heap profile.
    pub(super) async fn heap_profile(settings: Arc<TelemetrySettings>) -> Result<String> {
        let mem_profiler = profiler(settings)?;
        mem_profiler.heap_profile().await
    }

    /// Obtains the profiler and returns its heap statistics.
    pub(super) async fn heap_stats(settings: Arc<TelemetrySettings>) -> Result<String> {
        let mem_profiler = profiler(settings)?;
        mem_profiler.heap_stats()
    }
}