use bytes::Bytes;
use http_body_util::Full;
use hyper::{header, service::Service, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use prometheus::{Encoder, Registry, TextEncoder};
use std::{convert::Infallible, future::Future, net::SocketAddr, ops::Deref, pin::Pin};
use tokio::net::TcpListener;
use tracing::{info, trace};
#[cfg(feature = "internal_metrics")]
use prometheus::{
register_histogram_with_registry, register_int_counter_with_registry, register_int_gauge_with_registry, Histogram,
IntCounter, IntGauge,
};
#[cfg(feature = "internal_metrics")]
use std::convert::TryInto;
/// Boxed, call-once closure that is handed a `&Registry` and returns
/// `Ok(())` or a [`prometheus::Error`].
///
/// NOTE(review): this alias is not used in the part of the file visible here;
/// presumably callers use it to register their collectors — confirm.
pub type RegistryFn = Box<dyn FnOnce(&Registry) -> Result<(), prometheus::Error>>;
/// Field-less type that only serves as a namespace for the associated
/// function [`Server::run`], which starts the metrics HTTP server.
pub struct Server {}
impl Server {
pub async fn run<S, F, R>(registry: R, addr: S, shutdown: F) -> Result<(), std::io::Error>
where
S: Into<SocketAddr>,
F: Future<Output = ()>,
R: Deref<Target = Registry> + Clone + Send + 'static,
{
let addr = addr.into();
#[cfg(feature = "internal_metrics")]
let durations = register_histogram_with_registry!(
"prometheus_exporter_request_duration_seconds",
"HTTP request durations in seconds",
registry
)
.unwrap();
#[cfg(feature = "internal_metrics")]
let requests = register_int_counter_with_registry!(
"prometheus_exporter_requests_total",
"HTTP requests received in metrics endpoint",
registry
)
.unwrap();
#[cfg(feature = "internal_metrics")]
let sizes = register_int_gauge_with_registry!(
"prometheus_exporter_response_size_bytes",
"HTTP response sizes in bytes",
registry
)
.unwrap();
info!("starting hyper server to serve metrics");
let service = MetricsService {
registry: registry.clone(),
#[cfg(feature = "internal_metrics")]
durations: durations.clone(),
#[cfg(feature = "internal_metrics")]
requests: requests.clone(),
#[cfg(feature = "internal_metrics")]
sizes: sizes.clone(),
};
let listener = TcpListener::bind(addr).await?;
let mut shutdown = core::pin::pin!(shutdown);
while let Some(conn) = tokio::select! {
_ = shutdown.as_mut() => None,
conn = listener.accept() => Some(conn),
} {
match conn {
Ok((tcp, _)) => {
let io = TokioIo::new(tcp);
let service_clone = service.clone();
tokio::task::spawn(async move {
use hyper::server::conn::http1;
let conn = http1::Builder::new().serve_connection(io, service_clone);
if let Err(e) = conn.await {
tracing::error!(?e, "error serving connection")
}
});
},
Err(e) => tracing::error!(?e, "error accepting new connection"),
}
}
#[cfg(feature = "internal_metrics")]
{
if let Err(e) = registry.unregister(Box::new(durations)) {
tracing::error!(?e, "could not unregister 'durations'");
};
if let Err(e) = registry.unregister(Box::new(requests)) {
tracing::error!(?e, "could not unregister 'requests'");
};
if let Err(e) = registry.unregister(Box::new(sizes)) {
tracing::error!(?e, "could not unregister 'sizes'");
};
}
Ok(())
}
}
/// Request handler handed to hyper for each connection (variant compiled when
/// the `internal_metrics` feature is enabled).
///
/// Besides the registry it carries the three self-monitoring collectors that
/// `Server::run` registers and that `call` updates on every request.
#[cfg(feature = "internal_metrics")]
#[derive(Debug, Clone)]
struct MetricsService<R> {
/// Handle to the registry whose metrics are gathered for `/metrics`.
registry: R,
/// Histogram timed around the handling of every request.
durations: Histogram,
/// Counter incremented for each request to `/metrics`.
requests: IntCounter,
/// Gauge set to the byte length of the most recently encoded `/metrics` body.
sizes: IntGauge,
}
/// Request handler handed to hyper for each connection (variant compiled when
/// the `internal_metrics` feature is disabled): it only carries the registry.
#[cfg(not(feature = "internal_metrics"))]
#[derive(Debug, Clone)]
struct MetricsService<R> {
/// Handle to the registry whose metrics are gathered for `/metrics`.
registry: R,
}
/// hyper `Service` implementation: answers `/metrics` with the text-encoded
/// contents of the registry and every other path with a plain-text 404.
impl<R> Service<Request<hyper::body::Incoming>> for MetricsService<R>
where
    R: Deref<Target = Registry> + Clone + Send + 'static,
{
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    type Response = Response<Full<Bytes>>;

    /// Builds the response for `req` synchronously and returns it wrapped in
    /// an immediately-ready boxed future.
    ///
    /// Only the request path is inspected; the request body is never read.
    /// With `internal_metrics` enabled, the duration histogram is observed for
    /// every request, while the request counter and the size gauge are only
    /// touched for `/metrics`.
    fn call(&self, req: Request<hyper::body::Incoming>) -> Self::Future {
        #[cfg(feature = "internal_metrics")]
        let request_timer = self.durations.start_timer();

        let (status, payload) = match req.uri().path() {
            "/metrics" => {
                #[cfg(feature = "internal_metrics")]
                self.requests.inc();
                trace!("request");

                let families = self.registry.deref().gather();
                let mut encoded: Vec<u8> = Vec::new();
                TextEncoder::new()
                    .encode(&families, &mut encoded)
                    .expect("write to vec cannot fail");

                // The gauge takes a narrower integer than `usize`; skip the
                // update if the length does not fit.
                #[cfg(feature = "internal_metrics")]
                if let Ok(size) = encoded.len().try_into() {
                    self.sizes.set(size);
                }

                (StatusCode::OK, Bytes::from(encoded))
            },
            _ => {
                trace!("wrong uri, return 404");
                (StatusCode::NOT_FOUND, Bytes::from("404 not found"))
            },
        };

        let response = Response::builder()
            .status(status)
            .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
            .body(Full::new(payload))
            .unwrap();

        #[cfg(feature = "internal_metrics")]
        request_timer.observe_duration();

        let outcome: Result<Self::Response, Self::Error> = Ok(response);
        Box::pin(async move { outcome })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Empty;
    use hyper::Request;
    use std::{sync::Arc, time::Duration};
    use tokio::{net::TcpStream, sync::Notify};

    /// Opens an HTTP/1 client connection to `0.0.0.0:port`, sends
    /// `GET http://localhost:{port}{path}` with an empty body and returns the
    /// status code of the response.
    async fn get_status(port: u16, path: &str) -> StatusCode {
        let stream = TcpStream::connect(SocketAddr::from(([0; 4], port))).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        // The connection future has to be polled for the request to make progress.
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                println!("Connection failed: {:?}", err);
            }
        });
        let req = Request::builder()
            .method("GET")
            .uri(format!("http://localhost:{port}{path}"))
            .body(Empty::<Bytes>::new())
            .expect("request builder");
        let res = sender.send_request(req).await.expect("couldn't reach server");
        res.status()
    }

    /// The server starts with an `Arc<Registry>` and returns `Ok` once the
    /// shutdown future is notified.
    #[tokio::test]
    async fn test_create() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);
        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6001)),
                shutdown_clone.notified(),
            )
            .await
        });
        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// The server also accepts the `&'static Registry` returned by
    /// `prometheus::default_registry()`.
    #[tokio::test]
    async fn test_default() {
        let shutdown = Arc::new(Notify::new());
        let registry = prometheus::default_registry();
        let shutdown_clone = Arc::clone(&shutdown);
        let r = tokio::spawn(async move {
            Server::run(registry, SocketAddr::from(([0; 4], 6002)), shutdown_clone.notified()).await
        });
        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// `GET /metrics` is answered with `200 OK`.
    #[tokio::test]
    async fn test_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);
        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6003)),
                shutdown_clone.notified(),
            )
            .await
        });
        // Give the spawned server time to bind its listener.
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert_eq!(get_status(6003, "/metrics").await, StatusCode::OK);
        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }

    /// Any path other than `/metrics` is answered with `404 Not Found`.
    #[tokio::test]
    async fn test_wrong_endpoint_sample() {
        let shutdown = Arc::new(Notify::new());
        let registry = Arc::new(Registry::new());
        let shutdown_clone = Arc::clone(&shutdown);
        let r = tokio::spawn(async move {
            Server::run(
                Arc::clone(&registry),
                SocketAddr::from(([0; 4], 6004)),
                shutdown_clone.notified(),
            )
            .await
        });
        // Give the spawned server time to bind its listener.
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert_eq!(get_status(6004, "/foobar").await, StatusCode::NOT_FOUND);
        shutdown.notify_one();
        r.await.expect("tokio error").expect("prometheus_hyper server error");
    }
}