use ahash::AHashMap;
use futures_util::future;
use std::{
fmt,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tracing::{debug, info_span, Instrument};
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
mod diagnostics;
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
pub(crate) use self::diagnostics::Diagnostics;
#[cfg(all(
feature = "runtime",
feature = "runtime-diagnostics",
feature = "lease"
))]
pub(crate) use self::diagnostics::LeaseDiagnostics;
/// Error returned by [`Builder::bind`] when the admin server's listener
/// cannot be set up.
///
/// Wraps the underlying I/O error; the `#[from]` conversion lets `bind` use
/// `?` directly on `std::io::Result` values.
#[derive(Debug, thiserror::Error)]
#[error("failed to bind admin server: {0}")]
pub struct BindError(#[from] std::io::Error);
/// A request as received from hyper, with a streaming incoming body.
type Request = hyper::Request<hyper::body::Incoming>;
/// Response bodies are a single in-memory `Bytes` buffer.
type Body = http_body_util::Full<bytes::Bytes>;
/// A response produced by any admin route.
type Response = hyper::Response<Body>;
/// A synchronous route handler. `handle` invokes these via
/// `tokio::task::spawn_blocking`, hence the `Send + Sync + 'static` bounds.
type HandlerFn = Box<dyn Fn(Request) -> Response + Send + Sync + 'static>;
#[cfg(feature = "prometheus-client")]
mod metrics;
// Command-line arguments that configure the admin server.
//
// These are plain `//` comments on purpose: with the `clap` feature enabled this
// type derives `clap::Parser`, which would turn `///` doc comments into `--help`
// text and change the CLI's output.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
#[cfg_attr(docsrs, doc(cfg(feature = "admin")))]
pub struct AdminArgs {
    // Socket address the admin server listens on. With `clap` this is a long
    // flag defaulting to `0.0.0.0:8080`, the same value `AdminArgs::default()`
    // uses.
    #[cfg_attr(feature = "clap", clap(long, default_value = "0.0.0.0:8080"))]
    pub admin_addr: SocketAddr,
}
/// Configures the admin HTTP server before its listener is bound.
///
/// Besides the routes registered with [`Builder::with_handler`], the server
/// always answers `/live` and `/ready` (and `/kubert.json` when the `runtime`
/// and `runtime-diagnostics` features are enabled).
#[cfg_attr(docsrs, doc(cfg(feature = "admin")))]
pub struct Builder {
    /// Address the listener will be bound to.
    addr: SocketAddr,
    /// Shared readiness flag reported by `/ready`.
    ready: Readiness,
    /// Custom handlers, keyed by the exact request path they serve.
    routes: AHashMap<String, HandlerFn>,
    /// Runtime diagnostics served at `/kubert.json`.
    #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
    diagnostics: Diagnostics,
}
/// An admin server whose listener is bound but which is not yet serving.
///
/// Created by [`Builder::bind`]; call [`Bound::spawn`] to start accepting
/// connections.
#[cfg_attr(docsrs, doc(cfg(feature = "admin")))]
pub struct Bound {
    /// Address recorded when the listener was bound.
    addr: SocketAddr,
    /// Shared readiness flag reported by `/ready`.
    ready: Readiness,
    /// The bound listener; connections are accepted from it once spawned.
    listener: tokio::net::TcpListener,
    /// HTTP/1 connection settings applied to every accepted connection.
    server: hyper::server::conn::http1::Builder,
    /// Custom handlers, keyed by the exact request path they serve.
    routes: AHashMap<String, HandlerFn>,
    /// Runtime diagnostics served at `/kubert.json`.
    #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
    diagnostics: Diagnostics,
}
/// A cloneable handle to the admin server's readiness flag.
///
/// All clones share one `AtomicBool` (through the `Arc`), so setting it through
/// any handle changes what the `/ready` endpoint reports.
#[cfg_attr(docsrs, doc(cfg(feature = "admin")))]
#[derive(Clone, Debug)]
pub struct Readiness(Arc<AtomicBool>);
/// Handle to a running admin server, returned by [`Bound::spawn`].
#[cfg_attr(docsrs, doc(cfg(feature = "admin")))]
#[derive(Debug)]
pub struct Server {
    /// Address recorded when the server was bound.
    addr: SocketAddr,
    /// Shared readiness flag reported by `/ready`.
    ready: Readiness,
    /// The accept-loop task. The loop never exits on its own, so this handle
    /// only completes with an error (e.g. if the task is aborted or panics).
    task: tokio::task::JoinHandle<Result<(), hyper::Error>>,
}
impl Default for AdminArgs {
    /// Listens on port 8080 on every IPv4 interface (`0.0.0.0:8080`), the same
    /// value used as the `clap` default.
    fn default() -> Self {
        let admin_addr = SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 8080));
        Self { admin_addr }
    }
}
impl AdminArgs {
    /// Converts these arguments into a [`Builder`] that will listen on
    /// `admin_addr`.
    pub fn into_builder(self) -> Builder {
        let Self { admin_addr } = self;
        Builder::new(admin_addr)
    }
}
impl Default for Builder {
    /// Equivalent to building from [`AdminArgs::default`].
    fn default() -> Self {
        Self::from(AdminArgs::default())
    }
}
impl From<AdminArgs> for Builder {
fn from(args: AdminArgs) -> Self {
args.into_builder()
}
}
impl Builder {
    /// Creates a builder for an admin server that will listen on `addr`.
    ///
    /// The server starts out *not ready*: `/ready` fails until
    /// [`Builder::set_ready`] (or [`Readiness::set`] on a handle) is called.
    pub fn new(addr: SocketAddr) -> Self {
        Self {
            addr,
            ready: Readiness(Arc::new(false.into())),
            routes: Default::default(),
            #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
            diagnostics: Diagnostics::new(),
        }
    }

    /// Returns a handle to the readiness flag reported by `/ready`.
    pub fn readiness(&self) -> Readiness {
        self.ready.clone()
    }

    /// Marks the server as ready, so `/ready` responds with `200 OK`.
    pub fn set_ready(&self) {
        self.ready.set(true);
    }

    /// Registers a Prometheus endpoint at `/metrics` backed by `registry`.
    ///
    /// Process metrics are registered under the `process` prefix; if that
    /// fails, a warning is logged and the endpoint is still installed. When
    /// built with the `tokio_unstable` cfg, Tokio runtime metrics are also
    /// registered under the `tokio_rt` prefix and updated on a one-second
    /// interval by a spawned task. That spawn and `Handle::current()` require
    /// a Tokio runtime context.
    #[cfg(feature = "prometheus-client")]
    #[cfg_attr(docsrs, doc(cfg(feature = "prometheus-client")))]
    pub fn with_prometheus(self, mut registry: prometheus_client::registry::Registry) -> Self {
        #[cfg(not(tokio_unstable))]
        tracing::debug!("Tokio runtime metrics cannot be monitored without the tokio_unstable cfg");

        #[cfg(tokio_unstable)]
        {
            let metrics = kubert_prometheus_tokio::Runtime::register(
                registry.sub_registry_with_prefix("tokio_rt"),
                tokio::runtime::Handle::current(),
            );
            let mut interval = tokio::time::interval(Duration::from_secs(1));
            tokio::spawn(
                async move { metrics.updated(&mut interval).await }
                    .instrument(tracing::info_span!("kubert-prom-tokio-rt")),
            );
        }

        if let Err(error) =
            kubert_prometheus_process::register(registry.sub_registry_with_prefix("process"))
        {
            tracing::warn!(%error, "Process metrics cannot be monitored");
        }

        self.with_prometheus_handler("/metrics", registry)
    }

    /// Registers a Prometheus metrics endpoint for `registry` at `path`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Builder::with_handler`].
    #[cfg(feature = "prometheus-client")]
    #[cfg_attr(docsrs, doc(cfg(feature = "prometheus-client")))]
    pub fn with_prometheus_handler(
        self,
        path: impl ToString,
        registry: prometheus_client::registry::Registry,
    ) -> Self {
        let prom = metrics::Prometheus::new(registry);
        self.with_handler(path, move |req| prom.handle_metrics(req))
    }

    /// Registers `handler` to serve requests whose path is exactly `path`.
    ///
    /// Handlers are synchronous and are invoked on Tokio's blocking thread
    /// pool. Registering the same path twice replaces the earlier handler.
    /// When runtime diagnostics are enabled, `/kubert.json` is answered by a
    /// built-in handler first, so a handler registered there is never called.
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/ready` or `/live`, which are served by built-in
    /// handlers.
    pub fn with_handler(
        mut self,
        path: impl ToString,
        handler: impl Fn(Request) -> Response + Send + Sync + 'static,
    ) -> Self {
        let path = path.to_string();
        assert_ne!(
            path, "/ready",
            "the built-in `/ready` handler cannot be overridden"
        );
        assert_ne!(
            path, "/live",
            "the built-in `/live` handler cannot be overridden"
        );
        self.routes.insert(path, Box::new(handler));
        self
    }

    /// Binds the admin server's listener without starting to serve.
    ///
    /// The recorded address is read back from the bound socket, so binding
    /// port `0` reports the port the OS actually assigned (see
    /// [`Server::local_addr`]). Accepted connections use HTTP/1 with
    /// half-close enabled, a 2 second header-read timeout and an 8 KiB
    /// maximum buffer size.
    ///
    /// This must be called from within a Tokio runtime, because it registers
    /// the listener with Tokio via `TcpListener::from_std`.
    ///
    /// # Errors
    ///
    /// Returns [`BindError`] if the socket cannot be bound, configured, or
    /// registered with Tokio.
    pub fn bind(self) -> Result<Bound, BindError> {
        let Self {
            addr,
            ready,
            routes,
            #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
            diagnostics,
        } = self;

        let lis = std::net::TcpListener::bind(addr)?;
        lis.set_nonblocking(true)?;
        // Record the address actually bound rather than the requested one, so
        // that e.g. an ephemeral port (`:0`) is reported correctly.
        let addr = lis.local_addr()?;
        let listener = tokio::net::TcpListener::from_std(lis)?;

        let mut server = hyper::server::conn::http1::Builder::new();
        server
            .half_close(true)
            .timer(hyper_util::rt::TokioTimer::default())
            .header_read_timeout(Duration::from_secs(2))
            .max_buf_size(8 * 1024);

        Ok(Bound {
            addr,
            ready,
            server,
            listener,
            routes,
            #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
            diagnostics,
        })
    }
}
impl fmt::Debug for Builder {
    /// Shows only the address and readiness flag; route handlers are opaque
    /// closures and are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Builder")
            .field("addr", &self.addr)
            .field("ready", &self.ready)
            .finish()
    }
}
impl Bound {
pub fn readiness(&self) -> Readiness {
self.ready.clone()
}
pub fn set_ready(&self) {
self.ready.set(true);
}
pub fn spawn(self) -> Server {
let Self {
ready,
server,
listener,
routes,
addr,
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
diagnostics,
} = self;
let task = tokio::spawn({
let ready = ready.clone();
let routes = Arc::new(routes);
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
let diagnostics = diagnostics.clone();
async move {
loop {
let (stream, client_addr) = match listener.accept().await {
Ok(socket) => socket,
Err(error) => {
tracing::warn!(%error, "Failed to accept connection");
continue;
}
};
if let Err(error) = stream.set_nodelay(true) {
tracing::warn!(%error, "Failed to set TCP_NODELAY");
}
tracing::trace!(client.addr = ?client_addr, "Accepted connection");
let svc = {
use tower::ServiceExt;
let ready = ready.clone();
let routes = routes.clone();
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
let diagnostics = diagnostics.clone();
let svc = tower::service_fn(move |req: Request| {
handle(
&ready,
&routes,
req,
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
(client_addr, &diagnostics),
)
});
#[cfg(any(feature = "admin-brotli", feature = "admin-gzip"))]
let svc = tower_http::compression::Compression::new(svc);
hyper::service::service_fn(move |req| svc.clone().oneshot(req))
};
let serve =
server.serve_connection(hyper_util::rt::TokioIo::new(stream), svc.clone());
tokio::spawn(
async move {
debug!("Serving");
serve.await
}
.instrument(
tracing::debug_span!("conn", client.addr = %client_addr).or_current(),
),
);
}
}
.instrument(info_span!("admin", port = %self.addr.port()))
});
Server { task, addr, ready }
}
#[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
pub(crate) fn diagnostics(&self) -> &Diagnostics {
&self.diagnostics
}
}
impl Readiness {
pub fn get(&self) -> bool {
self.0.load(Ordering::Acquire)
}
pub fn set(&self, ready: bool) {
self.0.store(ready, Ordering::Release);
}
}
impl Server {
pub fn local_addr(&self) -> SocketAddr {
self.addr
}
pub fn readiness(&self) -> Readiness {
self.ready.clone()
}
pub fn into_join_handle(self) -> tokio::task::JoinHandle<Result<(), hyper::Error>> {
self.task
}
}
/// Routes a single admin request.
///
/// The built-in `/live` and `/ready` endpoints (and `/kubert.json`, when
/// runtime diagnostics are enabled) are answered inline and take precedence
/// over registered routes. Any other path found in `routes` is served by its
/// handler on Tokio's blocking thread pool. Everything else gets
/// `404 Not Found`.
///
/// The returned future yields `Err` only if the blocking handler task fails
/// to complete, for example because the handler panicked.
fn handle(
    ready: &Readiness,
    routes: &Arc<AHashMap<String, HandlerFn>>,
    req: Request,
    #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))] (client_addr, diagnostics): (
        std::net::SocketAddr,
        &Diagnostics,
    ),
) -> Pin<Box<dyn std::future::Future<Output = Result<Response, tokio::task::JoinError>> + Send>> {
    let path = req.uri().path();
    match path {
        "/live" => Box::pin(future::ok(handle_live(req))),
        "/ready" => Box::pin(future::ok(handle_ready(ready, req))),
        #[cfg(all(feature = "runtime", feature = "runtime-diagnostics"))]
        "/kubert.json" => Box::pin(future::ok(diagnostics.handle(client_addr, req))),
        registered if routes.contains_key(registered) => {
            // Handlers are synchronous, so run them off the async workers. The
            // blocking closure must be `'static`, so it gets its own copies of
            // the route table and the path.
            let table = Arc::clone(routes);
            let key = registered.to_owned();
            Box::pin(tokio::task::spawn_blocking(move || {
                let handler = table.get(&key).expect("routes must contain path");
                handler(req)
            }))
        }
        _ => Box::pin(future::ok(
            hyper::Response::builder()
                .status(hyper::StatusCode::NOT_FOUND)
                .body(Body::default())
                .unwrap(),
        )),
    }
}
/// Answers liveness probes.
///
/// `GET` and `HEAD` get `200 OK` with a `text/plain` body of `alive`. Any
/// other method gets `405 Method Not Allowed` with `Allow: GET, HEAD`.
fn handle_live(req: Request) -> Response {
    let is_probe = matches!(*req.method(), hyper::Method::GET | hyper::Method::HEAD);
    if !is_probe {
        return hyper::Response::builder()
            .status(hyper::StatusCode::METHOD_NOT_ALLOWED)
            .header(hyper::header::ALLOW, "GET, HEAD")
            .body(Body::default())
            .unwrap();
    }

    hyper::Response::builder()
        .status(hyper::StatusCode::OK)
        .header(hyper::header::CONTENT_TYPE, "text/plain")
        .body("alive\n".into())
        .unwrap()
}
/// Answers readiness probes.
///
/// For `GET`/`HEAD`, responds `200 OK` ("ready") once the readiness flag is
/// set, and `500 Internal Server Error` ("not ready") otherwise. Both bodies
/// are `text/plain`. Any other method gets `405 Method Not Allowed` with
/// `Allow: GET, HEAD`.
fn handle_ready(Readiness(flag): &Readiness, req: Request) -> Response {
    if !matches!(*req.method(), hyper::Method::GET | hyper::Method::HEAD) {
        return hyper::Response::builder()
            .status(hyper::StatusCode::METHOD_NOT_ALLOWED)
            .header(hyper::header::ALLOW, "GET, HEAD")
            .body(Body::default())
            .unwrap();
    }

    let (status, text) = if flag.load(Ordering::Acquire) {
        (hyper::StatusCode::OK, "ready\n")
    } else {
        (hyper::StatusCode::INTERNAL_SERVER_ERROR, "not ready\n")
    };

    hyper::Response::builder()
        .status(status)
        .header(hyper::header::CONTENT_TYPE, "text/plain")
        .body(text.into())
        .unwrap()
}