use std::net::SocketAddr;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode, header};
use hyper_util::rt::TokioIo;
use log::{error, warn};
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
/// Boxed response body type used for every response this server sends.
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
/// Body of the `404 Not Found` response returned for any request other than `GET /metrics`.
static NOTFOUND: &[u8] = b"404: Not Found";
/// Errors that stop the metrics HTTP server or prevent it from building a response.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum MetricsError {
    /// Socket-level failure, e.g. binding the listening port or accepting a connection.
    #[error("I/O Error: {0}")]
    IoError(#[from] std::io::Error),
    /// Failure while assembling an HTTP response with `Response::builder()`.
    #[error("Hyper error: {0}")]
    HyperError(#[from] hyper::http::Error),
}
/// Starts the metrics HTTP server on a background Tokio task.
///
/// The server listens on `0.0.0.0:<port>` and answers `GET /metrics` with the text
/// produced by `metrics_encode_fn`; any other request receives a 404.
///
/// The returned handle resolves once the server stops, yielding the same result that
/// was logged at `warn` level. Aborting the handle stops the server from accepting
/// new connections.
///
/// # Panics
/// Panics if called outside a Tokio runtime (see `tokio::spawn`).
pub fn start_http_server(
    port: u16,
    metrics_encode_fn: fn() -> String,
) -> JoinHandle<Result<(), MetricsError>> {
    let serve_and_report = async move {
        let server = MetricsServer {
            port,
            metrics_encode_fn,
        };
        let outcome = server.run_server().await;
        warn!("HTTP metrics server stopped: {:?}", outcome);
        outcome
    };
    tokio::spawn(serve_and_report)
}
struct MetricsServer {
port: u16,
metrics_encode_fn: fn() -> String,
}
impl MetricsServer {
async fn run_server(&self) -> Result<(), MetricsError> {
let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
let listener = TcpListener::bind(addr).await?;
loop {
let (stream, _) = listener.accept().await?;
self.handle_connection(stream).await;
}
}
async fn handle_connection(&self, stream: tokio::net::TcpStream) {
let io = TokioIo::new(stream);
let service = service_fn(|req| self.routes(req));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
error!("Failed to serve metrics connection: {:?}", err);
}
}
async fn routes(&self, req: Request<Incoming>) -> Result<Response<BoxBody>, MetricsError> {
match (req.method(), req.uri().path()) {
(&Method::GET, "/metrics") => self.get_metrics(),
_ => not_found(),
}
}
fn get_metrics(&self) -> Result<Response<BoxBody>, MetricsError> {
let body = (self.metrics_encode_fn)();
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/plain")
.body(full(body))?)
}
}
fn not_found() -> Result<Response<BoxBody>, MetricsError> {
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(full(NOTFOUND))?)
}
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Empty;
    use hyper::Uri;
    use hyper::body::Body;
    use hyper::client::conn;
    use hyper::client::conn::http1::{Connection, SendRequest};
    use prometheus::{IntCounter, register_int_counter};
    use serial_test::serial;
    use std::sync::OnceLock;
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Port used by the tests that start a real server. Those tests are marked
    /// `#[serial(port_usage)]` so they never bind it at the same time.
    const PORT: u16 = 9090;
    /// How long a test waits for the spawned server to start accepting connections.
    const STARTUP_TIMEOUT: Duration = Duration::from_secs(5);
    /// Delay between connection attempts while waiting for the server.
    const STARTUP_POLL_INTERVAL: Duration = Duration::from_millis(10);

    /// Renders every metric in the default Prometheus registry in text format,
    /// falling back to an empty string if encoding fails.
    pub fn metrics_to_string() -> String {
        let encoder = prometheus::TextEncoder::new();
        encoder
            .encode_to_string(&prometheus::gather())
            .unwrap_or_default()
    }

    /// Returns a counter that is registered in the default registry on first use,
    /// so the exported metrics are non-empty.
    pub fn high_five_counter() -> &'static IntCounter {
        static CONSUMED_MESSAGES: OnceLock<IntCounter> = OnceLock::new();
        CONSUMED_MESSAGES.get_or_init(|| {
            register_int_counter!("highfives", "Number of highfives given").unwrap()
        })
    }

    /// Waits until something accepts TCP connections on `localhost:port`.
    ///
    /// `start_http_server` binds inside a spawned task, so a request sent right after it
    /// returns can race the bind. This function polls instead of using a fixed one-second
    /// sleep, which slowed every test and could still lose the race on a loaded machine.
    ///
    /// # Panics
    /// Panics if no connection succeeds within [`STARTUP_TIMEOUT`].
    async fn wait_for_server(port: u16) {
        let deadline = tokio::time::Instant::now() + STARTUP_TIMEOUT;
        while TcpStream::connect(("localhost", port)).await.is_err() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "metrics server did not start listening on port {port} within {STARTUP_TIMEOUT:?}"
            );
            tokio::time::sleep(STARTUP_POLL_INTERVAL).await;
        }
    }

    /// Opens an HTTP/1 client connection to the host and port of `url`
    /// (defaulting to [`PORT`] when the URI has no port).
    async fn create_client(
        url: &Uri,
    ) -> (
        SendRequest<Empty<Bytes>>,
        Connection<TokioIo<TcpStream>, Empty<Bytes>>,
    ) {
        let host = url.host().expect("URI has no host");
        let port = url.port_u16().unwrap_or(PORT);
        let addr = format!("{}:{}", host, port);
        let stream = TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        conn::http1::handshake(io).await.unwrap()
    }

    /// Builds a body-less `GET` request for `url` with a matching `Host` header.
    fn to_get_req(url: &Uri) -> Request<Empty<Bytes>> {
        Request::builder()
            .uri(url)
            .method(Method::GET)
            .header(header::HOST, url.authority().unwrap().as_str())
            .body(Empty::new())
            .unwrap()
    }

    #[tokio::test]
    async fn test_http_metric_response() {
        high_five_counter().inc();
        let server = MetricsServer {
            port: PORT,
            metrics_encode_fn: metrics_to_string,
        };
        let response = server.get_metrics().expect("failed to get metrics");
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain"
        );
        assert!(response.body().size_hint().exact().unwrap() > 0);
    }

    #[tokio::test]
    #[serial(port_usage)]
    async fn test_start_http_server() {
        let server = start_http_server(PORT, metrics_to_string);
        high_five_counter().inc();
        wait_for_server(PORT).await;
        let url: Uri = format!("http://localhost:{PORT}/metrics").parse().unwrap();
        let (mut request_sender, connection) = create_client(&url).await;
        tokio::task::spawn(async move {
            if let Err(err) = connection.await {
                error!("Connection failed: {:?}", err);
            }
        });
        let request = to_get_req(&url);
        let response = request_sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain"
        );
        let buf = response.collect().await.unwrap().to_bytes();
        let res = String::from_utf8(buf.to_vec()).unwrap();
        assert!(!res.is_empty());
        server.abort();
    }

    #[tokio::test]
    #[serial(port_usage)]
    async fn test_unknown_path() {
        let server = start_http_server(PORT, metrics_to_string);
        wait_for_server(PORT).await;
        let url: Uri = format!("http://localhost:{PORT}").parse().unwrap();
        let (mut request_sender, connection) = create_client(&url).await;
        tokio::task::spawn(async move {
            if let Err(err) = connection.await {
                error!("Connection failed: {:?}", err);
            }
        });
        let request = to_get_req(&url);
        let response = request_sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let buf = response.collect().await.unwrap().to_bytes();
        let res = String::from_utf8(buf.to_vec()).unwrap();
        assert_eq!(res, String::from_utf8_lossy(NOTFOUND));
        server.abort();
    }
}