#![allow(clippy::needless_continue)]
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::Instrument as _;
use crate::admin::cluster_info::{format_text, ClusterInfoSnapshot};
use crate::stats::prometheus::render_prometheus;
use crate::stats::snapshot::Snapshot;
/// Shared, thread-safe callback that produces a fresh [`ClusterInfoSnapshot`] on demand.
///
/// It is invoked once per `/cluster-info.txt` request.
pub type ClusterInfoProvider = Arc<dyn Fn() -> ClusterInfoSnapshot + Sync + Send>;
/// Size, in bytes, of the buffer a request head is read into.
///
/// A request that fills it without completing is answered with `400 Bad Request`.
pub const MAX_REQUEST_BYTES: usize = 8192;
/// Number of header slots handed to the HTTP parser for a single request.
pub const MAX_HEADERS: usize = 32;
/// Time limit applied while reading a request from a client.
///
/// On expiry, the connection is shut down without a response.
const READ_TIMEOUT: Duration = Duration::from_millis(5_000);
pub struct StatsServer {
listener: TcpListener,
source: Arc<Mutex<Snapshot>>,
cluster_info: Option<ClusterInfoProvider>,
}
impl StatsServer {
pub async fn bind(addr: SocketAddr, source: Arc<Mutex<Snapshot>>) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
Ok(Self {
listener,
source,
cluster_info: None,
})
}
#[must_use]
pub fn with_cluster_info_provider(mut self, provider: ClusterInfoProvider) -> Self {
self.cluster_info = Some(provider);
self
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
pub async fn accept_one(&self) -> io::Result<()> {
let (sock, _peer) = self.listener.accept().await?;
let snapshot = self.source.lock().clone();
let cluster_info = self.cluster_info.clone();
serve_connection(sock, snapshot, cluster_info).await
}
pub async fn run(self) -> io::Result<()> {
let span = tracing::info_span!(
"stats_server.run",
local = %self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
);
let listener = self.listener;
let source = self.source;
let cluster_info = self.cluster_info;
async move {
loop {
let (sock, _peer) = listener.accept().await?;
let snapshot = source.lock().clone();
let ci = cluster_info.clone();
tokio::spawn(async move {
let _ = serve_connection(sock, snapshot, ci).await;
});
}
}
.instrument(span)
.await
}
}
/// Reads one request head from `sock` and dispatches it to `handle_parsed`.
///
/// The read is bounded by [`MAX_REQUEST_BYTES`] and [`READ_TIMEOUT`].
///
/// `READ_TIMEOUT` is one overall deadline for the whole request head, not a per-read idle
/// timeout. With a per-read limit, a client trickling a byte at a time could keep the
/// connection and its request buffer alive for hours.
///
/// Outcomes:
/// - head parsed completely: `handle_parsed` writes the response;
/// - malformed head, or buffer filled before the head completes: `400 Bad Request`;
/// - deadline exceeded or read error: the socket is shut down without a response;
/// - peer closes before sending a complete head: returns `Ok(())` without responding.
///
/// # Errors
/// Returns I/O errors raised while writing the response.
async fn serve_connection(
    mut sock: TcpStream,
    snapshot: Snapshot,
    cluster_info: Option<ClusterInfoProvider>,
) -> io::Result<()> {
    let deadline = tokio::time::Instant::now() + READ_TIMEOUT;
    let mut buf = vec![0u8; MAX_REQUEST_BYTES];
    let mut filled = 0usize;
    loop {
        if filled == buf.len() {
            return write_response(&mut sock, 400, "Bad Request", b"").await;
        }
        let read_result = tokio::time::timeout_at(deadline, sock.read(&mut buf[filled..])).await;
        let Ok(Ok(n)) = read_result else {
            // Deadline hit or read failed: close best-effort, nothing useful to report.
            let _ = sock.shutdown().await;
            return Ok(());
        };
        if n == 0 {
            // Peer closed before completing the request head.
            return Ok(());
        }
        filled += n;
        // Re-parse from the start on every read; the buffer is bounded so this stays cheap.
        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut req = httparse::Request::new(&mut headers);
        match req.parse(&buf[..filled]) {
            Ok(httparse::Status::Complete(_)) => {
                return handle_parsed(&mut sock, &req, snapshot, cluster_info).await;
            }
            Ok(httparse::Status::Partial) => continue,
            Err(_) => {
                return write_response(&mut sock, 400, "Bad Request", b"").await;
            }
        }
    }
}
/// Routes a fully parsed request and writes the matching response.
///
/// Only `GET` is accepted. Any other method gets `405 Method Not Allowed` with an
/// `Allow: GET` header. Routing uses the path without its query string, so `/metrics?x=1`
/// is served as `/metrics`.
///
/// Routes:
/// - `/`, `/info`, `/stats`: the snapshot as JSON;
/// - `/metrics`: the snapshot rendered by `render_prometheus`;
/// - `/cluster-info.txt`: the provider's snapshot formatted as text. This is `503` when no
///   provider is installed and `500` if formatting fails;
/// - anything else: `200 OK` with body `OK\r\n`.
///
/// # Errors
/// Returns I/O errors raised while writing the response.
async fn handle_parsed(
    sock: &mut TcpStream,
    req: &httparse::Request<'_, '_>,
    snapshot: Snapshot,
    cluster_info: Option<ClusterInfoProvider>,
) -> io::Result<()> {
    if !matches!(req.method, Some("GET")) {
        return write_method_not_allowed(sock).await;
    }
    let target = req.path.unwrap_or("/");
    // Route on the path component only; a query string must not change the endpoint.
    let path = target.split_once('?').map_or(target, |(path, _query)| path);
    match path {
        "/" | "/info" | "/stats" => {
            let body = snapshot.to_json();
            write_json_response(sock, body.as_bytes()).await
        }
        "/metrics" => {
            let body = render_prometheus(&snapshot);
            write_metrics_response(sock, body.as_bytes()).await
        }
        "/cluster-info.txt" => match cluster_info {
            Some(provider) => {
                let snap = provider();
                let mut body: Vec<u8> = Vec::with_capacity(4096);
                if format_text(&snap, &mut body).is_err() {
                    return write_response(sock, 500, "Internal Server Error", b"").await;
                }
                write_text_response(sock, &body).await
            }
            None => write_response(sock, 503, "Service Unavailable", b"").await,
        },
        _ => write_response(sock, 200, "OK", b"OK\r\n").await,
    }
}

/// Writes an empty `405 Method Not Allowed` response that names `GET` as the only
/// supported method, then shuts down the write half of the socket.
///
/// RFC 9110 requires the `Allow` header in every 405 response.
async fn write_method_not_allowed(sock: &mut TcpStream) -> io::Result<()> {
    const RESPONSE: &[u8] = b"HTTP/1.1 405 Method Not Allowed\r\n\
        Allow: GET\r\n\
        Content-Length: 0\r\n\
        Connection: close\r\n\r\n";
    sock.write_all(RESPONSE).await?;
    sock.shutdown().await
}
/// Sends `body` as a `200 OK` plain-text (US-ASCII) response marked `Connection: close`.
///
/// Afterwards it shuts down the write half of the socket.
///
/// # Errors
/// Returns any I/O error from writing the response or shutting down the socket.
async fn write_text_response(sock: &mut TcpStream, body: &[u8]) -> io::Result<()> {
    let content_length = body.len();
    let head = format!(
        "HTTP/1.1 200 OK\r\n\
         Content-Type: text/plain; charset=us-ascii\r\n\
         Content-Length: {content_length}\r\n\
         Connection: close\r\n\
         \r\n"
    );
    sock.write_all(head.as_bytes()).await?;
    sock.write_all(body).await?;
    sock.shutdown().await
}
/// Writes a bare `HTTP/1.1` response with the given status and body, without a
/// `Content-Type` header.
///
/// Afterwards it shuts down the write half of the socket.
///
/// # Errors
/// Returns any I/O error from writing the response or shutting down the socket.
async fn write_response(
    sock: &mut TcpStream,
    code: u16,
    reason: &str,
    body: &[u8],
) -> io::Result<()> {
    let status_and_headers = format!(
        "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        code,
        reason,
        body.len(),
    );
    let pieces: [&[u8]; 2] = [status_and_headers.as_bytes(), body];
    // An empty body is skipped entirely instead of issuing a zero-length write.
    for piece in pieces.into_iter().filter(|piece| !piece.is_empty()) {
        sock.write_all(piece).await?;
    }
    sock.shutdown().await
}
/// Sends `body` as a `200 OK` JSON (UTF-8) response marked `Connection: close`.
///
/// Afterwards it shuts down the write half of the socket.
///
/// # Errors
/// Returns any I/O error from writing the response or shutting down the socket.
async fn write_json_response(sock: &mut TcpStream, body: &[u8]) -> io::Result<()> {
    let length_line = format!("Content-Length: {}\r\n", body.len());
    let head = [
        "HTTP/1.1 200 OK\r\n",
        "Content-Type: application/json; charset=utf-8\r\n",
        length_line.as_str(),
        "Connection: close\r\n\r\n",
    ]
    .concat();
    sock.write_all(head.as_bytes()).await?;
    sock.write_all(body).await?;
    sock.shutdown().await
}
/// Sends `body` as a `200 OK` response using the Prometheus text exposition content type
/// (`text/plain; version=0.0.4`), marked `Connection: close`.
///
/// Afterwards it shuts down the write half of the socket.
///
/// # Errors
/// Returns any I/O error from writing the response or shutting down the socket.
async fn write_metrics_response(sock: &mut TcpStream, body: &[u8]) -> io::Result<()> {
    let head = format!(
        "HTTP/1.1 200 OK\r\n\
         Content-Type: text/plain; version=0.0.4; charset=utf-8\r\n\
         Content-Length: {len}\r\n\
         Connection: close\r\n\r\n",
        len = body.len(),
    );
    for part in [head.as_bytes(), body] {
        sock.write_all(part).await?;
    }
    sock.shutdown().await
}